diff --git a/changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml b/changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml new file mode 100644 index 000000000..d62356e20 --- /dev/null +++ b/changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml @@ -0,0 +1,2 @@ +changed: + - Update `llzk-pod-to-scalar` to split pods within arrays by creating multiple arrays. diff --git a/include/llzk/Dialect/POD/Transforms/TransformationPasses.td b/include/llzk/Dialect/POD/Transforms/TransformationPasses.td index 971d3bc93..65f93224d 100644 --- a/include/llzk/Dialect/POD/Transforms/TransformationPasses.td +++ b/include/llzk/Dialect/POD/Transforms/TransformationPasses.td @@ -15,7 +15,15 @@ include "llzk/Pass/PassBase.td" def PodToScalarPass : LLZKPass<"llzk-pod-to-scalar"> { let summary = "Replace PODs with scalar values"; let description = [{ - Replace `pod.type` values with the proper number of scalar values + Scalarize `pod.type` values by splitting POD-typed struct members into + multiple scalar members, splitting POD-typed array elements into parallel + arrays, then rewriting affected member accesses plus function signatures, + calls, and returns, and finally running POD-specific SROA + mem2reg cleanup + so the remaining POD storage is promoted to SSA values. + + If it is necessary to scalarize both PODs and arrays, run this pass before + running the `-llzk-array-to-scalar` pass because that pass will not scalarize + array types that are within a POD type. }]; } diff --git a/include/llzk/Util/TypeHelper.h b/include/llzk/Util/TypeHelper.h index eebd173ab..d76e58259 100644 --- a/include/llzk/Util/TypeHelper.h +++ b/include/llzk/Util/TypeHelper.h @@ -171,6 +171,14 @@ namespace llzk { bool isDynamic(mlir::IntegerAttr intAttr); +/// Flatten any array-valued element type into the dimensions of `outerArrTy`. +/// +/// This is used when an LLZK array logically resolves to a higher-rank array even though array +/// element types cannot themselves be arrays. The returned type keeps `outerArrTy`'s leading +/// dimensions, appends any nested dimensions from `elementType`, and uses the innermost non-array +/// element type as the final element type. +array::ArrayType flattenArrayElementType(array::ArrayType outerArrTy, mlir::Type elementType); + /// Compute the cardinality (i.e. number of scalar constraints) for an EmitEqualityOp type since the /// op can be used to constrain two same-size arrays. uint64_t computeEmitEqCardinality(mlir::Type type); diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 4e3778609..980f69691 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -12,44 +12,47 @@ /// /// The steps of this transformation are as follows: /// -/// 0. Scan to find `llzk.nondet` ops that allocate uninitialized pods and replace them with -/// an equivalent `pod.new` +/// 0. Rewrite pod-typed `llzk.nondet` allocations into `pod.new` so later stages only need to +/// reason about POD storage through POD dialect operations. /// -/// 1. Run a dialect conversion that replaces `PodType` struct members with one scalar member per -/// record and remembers how each original member was split. +/// 1. Run a dialect conversion that replaces pod-typed struct members with one scalar member per +/// POD record, replaces array-typed struct members whose element type is a POD with one parallel +/// array member per POD record, and remembers how each original member was split for the later +/// rewriting steps. /// -/// 2. Run a dialect conversion that does the following: +/// 2. Run a dialect conversion that splits arrays whose element type is a POD into parallel arrays +/// in `llzk.nondet`, `array.*`, `constrain.eq`, `constrain.in`, `struct.readm`, `struct.writem`, +/// `function.def`, `function.call`, `function.return`, and bool quantifiers. /// -/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the members that were split in step 1 -/// so they instead perform scalar reads and writes from the new members. The transformation is -/// local to the current op. Therefore, when replacing the `MemberReadOp` a new pod is -/// created locally and all uses of the `MemberReadOp` are replaced with the new pod Value, -/// then each scalar member read is followed by scalar write into the new pod. Similarly, -/// when replacing a `MemberWriteOp`, each element in the pod operand needs a scalar read -/// from the pod followed by a scalar write to the new member. Making only local changes -/// keeps this step simple and later steps will optimize. +/// 3. Run a dialect conversion that does the following: +/// +/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the pod-typed struct members split in +/// step 1 so they instead perform reads and writes on the new scalar members. Reads and writes +/// are tracked through virtual POD placeholders so the conversion can keep propagating scalar +/// leaves instead of re-introducing aggregate POD storage. /// /// - Remove optional initialization from `NewPodOp` and instead insert a list of `WritePodOp` /// immediately following. /// -/// - Split pods to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary -/// create/read/write ops so the changes are as local as possible (just as described for -/// `MemberReadOp` and `MemberWriteOp`) +/// - Split remaining direct POD values to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp`. +/// When a rewritten op still needs POD contents locally, keep them in the same virtual +/// placeholder form for as long as possible and only materialize concrete `pod.write` +/// operations as a fallback for unresolved uses. /// -/// 3. Promote pod reads and writes out of `scf.if`, `scf.for`, and `scf.while` regions when the +/// 4. Promote pod reads and writes out of `scf.if`, `scf.for`, and `scf.while` regions when the /// access can be modeled as an SSA value flowing through the region boundary. This puts the /// pod accesses that mem2reg must eliminate into a parent block or loop-carried value. /// -/// 4. Run MLIR "sroa" pass to split each pod with `N` records into `N` pods with 1 record each +/// 5. Run MLIR "sroa" pass to split remaining POD allocations into single-record POD allocations /// (to prepare for the "mem2reg" pass because its API cannot split memory by itself). /// -/// 5. Run MLIR "mem2reg" pass to convert all single-record pod allocations and accesses into SSA +/// 6. Run MLIR "mem2reg" pass to convert all single-record POD allocations and accesses into SSA /// values. /// -/// 6. Remove pod allocations that become unread after memory promotion, then remove SSA values +/// 7. Remove POD allocations that become unread after memory promotion, then remove SSA values /// made dead by that cleanup. /// -/// ** Steps 4-6 are rerun while nested POD types are still being exposed, until a fixpoint. +/// Steps 5-7 are rerun while nested POD types are still being exposed, until a fixpoint. /// /// Note: This transformation imposes a "last write wins" semantics on pod records. If /// different/configurable semantics are added in the future, some additional transformation would @@ -62,9 +65,13 @@ //===----------------------------------------------------------------------===// #include "llzk/Dialect/Array/IR/Dialect.h" +#include "llzk/Dialect/Array/IR/Ops.h" +#include "llzk/Dialect/Array/IR/Types.h" #include "llzk/Dialect/Bool/IR/Dialect.h" +#include "llzk/Dialect/Bool/IR/Ops.h" #include "llzk/Dialect/Cast/IR/Dialect.h" #include "llzk/Dialect/Constrain/IR/Dialect.h" +#include "llzk/Dialect/Constrain/IR/Ops.h" #include "llzk/Dialect/Felt/IR/Dialect.h" #include "llzk/Dialect/Function/IR/Dialect.h" #include "llzk/Dialect/Function/IR/Ops.h" @@ -76,6 +83,7 @@ #include "llzk/Dialect/POD/IR/Types.h" #include "llzk/Dialect/POD/Transforms/TransformationPasses.h" #include "llzk/Dialect/Polymorphic/IR/Dialect.h" +#include "llzk/Dialect/Polymorphic/IR/Ops.h" #include "llzk/Dialect/RAM/IR/Dialect.h" #include "llzk/Dialect/String/IR/Dialect.h" #include "llzk/Dialect/Struct/IR/Ops.h" @@ -83,9 +91,11 @@ #include "llzk/Transforms/LLZKTransformationPasses.h" #include "llzk/Transforms/SpecializedMemoryPasses.h" #include "llzk/Util/Concepts.h" +#include "llzk/Util/TypeHelper.h" #include "llzk/Util/Walk.h" #include +#include #include #include #include @@ -95,6 +105,9 @@ #include #include +#include +#include + // Include the generated base pass class definitions. namespace llzk::pod { #define GEN_PASS_DEF_PODTOSCALARPASS @@ -103,9 +116,11 @@ namespace llzk::pod { using namespace mlir; using namespace llzk; +using namespace llzk::array; using namespace llzk::pod; using namespace llzk::function; using namespace llzk::component; +using namespace llzk::polymorphic; #define DEBUG_TYPE "llzk-pod-to-scalar" @@ -144,6 +159,64 @@ template <> struct DenseMapInfo { namespace { +/// Return whether the given read/write access targets the same POD record. +inline static bool isSamePodRecord(ReadPodOp readOp, Value podRef, StringAttr recordName) { + return readOp.getPodRef() == podRef && readOp.getRecordNameAttr() == recordName; +} + +/// Return whether the given read/write access targets the same POD record. +inline static bool isSamePodRecord(WritePodOp writeOp, Value podRef, StringAttr recordName) { + return writeOp.getPodRef() == podRef && writeOp.getRecordNameAttr() == recordName; +} + +/// Return whether `op` contains a nested write to `podRef.recordName`. +static bool hasNestedWriteToRecord(Operation &op, Value podRef, StringAttr recordName) { + return walkContainsMatch(op, [&](WritePodOp writeOp) { + return writeOp.getOperation() != &op && isSamePodRecord(writeOp, podRef, recordName); + }); +} + +/// Return whether `op` contains a nested write to any record of `podRef`. +static bool hasNestedWriteToPod(Operation &op, Value podRef) { + return walkContainsMatch(op, [&](WritePodOp writeOp) { + return writeOp.getOperation() != &op && writeOp.getPodRef() == podRef; + }); +} + +/// Return whether `op` contains any read from `podRef.recordName`. +static bool hasReadFromRecord(Operation &op, Value podRef, StringAttr recordName) { + return walkContainsMatch(op, [&podRef, &recordName](ReadPodOp readOp) { + return isSamePodRecord(readOp, podRef, recordName); + }); +} + +/// Return whether `op` or any nested operation uses `value` as an operand. +static bool hasValueUse(Operation &op, Value value) { + return walkContainsMatch(op, [&value](Operation *nestedOp) { + return llvm::is_contained(nestedOp->getOperands(), value); + }); +} + +/// Return the nearest preceding same-record write that can be forwarded to `readOp`. +/// +/// This fold is intentionally conservative: it only forwards through intervening operations that do +/// not use the POD value at all. That keeps the rewrite local and avoids reasoning about other +/// whole-POD uses or record accesses that may observe mutation ordering. +static WritePodOp findNearestForwardableWriteInBlock(ReadPodOp readOp) { + Value podRef = readOp.getPodRef(); + StringAttr recordName = readOp.getRecordNameAttr(); + + for (Operation *op = readOp->getPrevNode(); op; op = op->getPrevNode()) { + if (!hasValueUse(*op, podRef)) { + continue; + } + + auto writeOp = dyn_cast(op); + return writeOp && isSamePodRecord(writeOp, podRef, recordName) ? writeOp : nullptr; + } + return nullptr; +} + /// If the given PodType can be split into scalars (always true for PodType), return it. inline static PodType splittablePod(PodType pt) { return pt; } @@ -178,15 +251,78 @@ template static bool containsSplittablePodType(ValueTypeRange ty return false; } +/// If the input ArrayType has a POD element type, return the input, else nullptr. +inline static ArrayType splittablePodArray(ArrayType at) { + return isa(at.getElementType()) ? at : nullptr; +} + +/// If the input Type is an ArrayType with a POD element type, return the input, else nullptr. +inline static ArrayType splittablePodArray(Type t) { + if (ArrayType at = dyn_cast(t)) { + return splittablePodArray(at); + } + return nullptr; +} + +/// Return the flattened leaf type addressed by `recordChain` within `type`. +static Type getFlattenedTypeAlongPath(Type type, ArrayRef recordChain) { + if (recordChain.empty()) { + return type; + } + + if (PodType podTy = dyn_cast(type)) { + Type nextType = podTy.getRecordMap().lookup(recordChain.front().getValue()); + assert(nextType && "record path must exist in the containing POD"); + return getFlattenedTypeAlongPath(nextType, recordChain.drop_front()); + } + + if (ArrayType arrTy = splittablePodArray(type)) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + Type nextType = elemPodTy.getRecordMap().lookup(recordChain.front().getValue()); + assert(nextType && "record path must exist in the POD array element type"); + return flattenArrayElementType( + arrTy, getFlattenedTypeAlongPath(nextType, recordChain.drop_front()) + ); + } + + llvm_unreachable("record path cannot continue through a non-POD leaf"); +} + +/// Visit each non-POD leaf record in `podTy`, providing its record-name chain and leaf type. +template +static void forEachPodLeaf(PodType podTy, SmallVectorImpl &recordChain, Fn &&callback) { + std::function walk = [&](Type type) { + if (PodType nestedPodTy = llvm::dyn_cast(type)) { + for (RecordAttr record : nestedPodTy.getRecords()) { + recordChain.push_back(record.getName()); + walk(record.getType()); + recordChain.pop_back(); + } + } else if (ArrayType arrTy = splittablePodArray(type)) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + for (RecordAttr record : elemPodTy.getRecords()) { + recordChain.push_back(record.getName()); + walk(flattenArrayElementType(arrTy, record.getType())); + recordChain.pop_back(); + } + } else { + callback(RecordChain(recordChain), type); + } + }; + + walk(podTy); +} + /// If the given Type is a PodType that can be split into scalars, append `collect` with all of /// the scalar types that result from splitting the PodType. Otherwise, just push the `Type`. size_t splitPodTypeTo(Type t, SmallVector &collect) { if (PodType pt = splittablePod(t)) { - auto records = pt.getRecords(); - for (RecordAttr record : records) { - collect.push_back(record.getType()); - } - return records.size(); + SmallVector recordChain; + size_t originalSize = collect.size(); + forEachPodLeaf(pt, recordChain, [&collect](const RecordChain &, Type leafType) { + collect.push_back(leafType); + }); + return collect.size() - originalSize; } else { collect.push_back(t); return 1; @@ -216,202 +352,2284 @@ splitPodType(TypeCollection types, SmallVector *originalIdxToSize = null return collect; } -/// Create a `pod.read` for one record of `podRef`. -inline static ReadPodOp -genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &rewriter) { - Type resultType = - llvm::cast(podRef.getType()).getRecordMap().lookup(recordName.getValue()); - return rewriter.create(loc, resultType, podRef, recordName); +/// Return `true` iff any type in the range is an array whose element type is a POD. +inline static bool containsSplittablePodArrayType(ArrayRef types) { + return llvm::any_of(types, [](Type t) { return splittablePodArray(t); }); } -/// Create a `pod.write` for one record of `podRef`. -inline static WritePodOp -genWrite(Location loc, Value podRef, StringAttr recordName, Value value, OpBuilder &rewriter) { - return rewriter.create(loc, podRef, recordName, value); +/// Return `true` iff any type in the range is an array whose element type is a POD. +template static bool containsSplittablePodArrayType(ValueTypeRange types) { + return llvm::any_of(types, [](Type t) { return splittablePodArray(t); }); } -/// Return the suffixes to append to a function arg/result name when splitting the given type. -static SmallVector getSplitRecordNameSuffixes(Type type) { - SmallVector suffixes; - if (PodType pt = splittablePod(type)) { - suffixes.reserve(pt.getRecords().size()); - for (RecordAttr record : pt.getRecords()) { - StringRef name = record.getName().getValue(); - std::string result; - result.reserve(name.size() + 1); - result.push_back('.'); - result.append(name.data(), name.size()); - suffixes.push_back(result); +/// If `t` is an array with POD element type, append one parallel array type for each POD leaf. +static size_t splitPodArrayTypeTo( + Type t, SmallVectorImpl &collect, SmallVector *splitIds = nullptr +) { + if (ArrayType at = splittablePodArray(t)) { + auto podTy = llvm::cast(at.getElementType()); + SmallVector recordChain; + size_t originalSize = collect.size(); + forEachPodLeaf(podTy, recordChain, [&](RecordChain id, Type leafType) { + collect.push_back(flattenArrayElementType(at, leafType)); + if (splitIds) { + splitIds->push_back(std::move(id)); + } + }); + return collect.size() - originalSize; + } + + collect.push_back(t); + return 1; +} + +/// Return the index-array carrier type used to preserve the shape of a zero-leaf array-of-POD. +static ArrayType getZeroLeafPodArrayShapeCarrierType(ArrayType arrTy) { + return arrTy.cloneWith(IndexType::get(arrTy.getContext())); +} + +/// Return `true` iff splitting `arrTy` produces no concrete POD leaf arrays. +static bool hasZeroLeafPodArraySplit(ArrayType arrTy) { + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes); + return splitTypes.empty(); +} + +/// Convert one type using the step-2 array-of-POD lowering convention. +/// +/// Most arrays-of-POD expand to one parallel array per POD leaf. When the element POD has no +/// leaves, keep a single index-array carrier so later rewrites can still preserve shape and affine +/// instantiation information. +static size_t convertPodArrayTypeTo(Type t, SmallVectorImpl &collect) { + if (ArrayType arrTy = splittablePodArray(t)) { + size_t oldSize = collect.size(); + splitPodArrayTypeTo(arrTy, collect); + if (collect.size() == oldSize) { + collect.push_back(getZeroLeafPodArrayShapeCarrierType(arrTy)); } + return collect.size() - oldSize; } - return suffixes; + + collect.push_back(t); + return 1; } -// If the operand has PodType, add reads from all pod records to the `newOperands` list otherwise -// add the original operand to the list. -static void processInputOperand( - Location loc, Value operand, SmallVector &newOperands, - ConversionPatternRewriter &rewriter +/// For each Type in the given input collection, call `convertPodArrayTypeTo(Type,...)`. +template +inline void convertPodArrayTypesTo( + TypeCollection types, SmallVectorImpl &collect, + SmallVector *originalIdxToSize = nullptr ) { - if (PodType pt = splittablePod(operand.getType())) { - for (RecordAttr record : pt.getRecords()) { - newOperands.push_back(genRead(loc, operand, record.getName(), rewriter)); + if (originalIdxToSize) { + originalIdxToSize->reserve(types.size()); + } + for (Type t : types) { + size_t count = convertPodArrayTypeTo(t, collect); + if (originalIdxToSize) { + originalIdxToSize->push_back(count); } - } else { - newOperands.push_back(operand); } } -/// For each operand with PodType, add reads from all pod records in place of the original operand -/// and update the op to use the new operands. -static void processInputOperands( - ValueRange operands, MutableOperandRange outputOpRef, Operation *op, - ConversionPatternRewriter &rewriter +/// Return the step-2 converted types for the given collection. +template +static SmallVector +convertPodArrayTypes(TypeCollection types, SmallVector *originalIdxToSize = nullptr) { + SmallVector collect; + convertPodArrayTypesTo(types, collect, originalIdxToSize); + return collect; +} + +/// For each Type in the given input collection, call `splitPodArrayTypeTo(Type,...)`. +template +inline void splitPodArrayTypeTo( + TypeCollection types, SmallVectorImpl &collect, SmallVector *originalIdxToSize ) { - SmallVector newOperands; - for (Value v : operands) { - processInputOperand(op->getLoc(), v, newOperands, rewriter); + for (Type t : types) { + size_t count = splitPodArrayTypeTo(t, collect); + if (originalIdxToSize) { + originalIdxToSize->push_back(count); + } } - rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() { - outputOpRef.assign(ValueRange(newOperands)); - }); } -/// Register the dialects and operations that remain legal across the conversion-based stages. -inline static void baseTargetSetup(ConversionTarget &target) { - target.addLegalDialect< - LLZKDialect, array::ArrayDialect, boolean::BoolDialect, cast::CastDialect, - constrain::ConstrainDialect, component::StructDialect, felt::FeltDialect, - function::FunctionDialect, global::GlobalDialect, include::IncludeDialect, pod::PODDialect, - polymorphic::PolymorphicDialect, ram::RAMDialect, string::StringDialect, arith::ArithDialect, - scf::SCFDialect>(); - target.addLegalOp(); +/// Return a list such that each non-array POD type is kept as-is, while each array-of-POD type is +/// replaced by one parallel array type per non-POD leaf record in the element POD. +template +inline SmallVector +splitPodArrayType(TypeCollection types, SmallVector *originalIdxToSize = nullptr) { + SmallVector collect; + splitPodArrayTypeTo(types, collect, originalIdxToSize); + return collect; } -/// Rewrite pod-typed `llzk.nondet` allocations into explicit `pod.new` allocations so the rest of -/// the pass only needs to reason about POD storage through POD dialect operations. -class NondetToNewPod : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - NonDetOp nondetOp, OpAdaptor, ConversionPatternRewriter &rewriter - ) const override { - if (auto pt = dyn_cast(nondetOp.getType())) { - rewriter.replaceOpWithNewOp(nondetOp, pt); - return success(); +/// Return the suffixes to append to a function arg/result name when splitting an array of PODs. +static SmallVector getSplitPodArrayRecordNameSuffixes(Type type) { + SmallVector suffixes; + if (ArrayType at = splittablePodArray(type)) { + SmallVector splitIds; + SmallVector ignoredTypes; + splitPodArrayTypeTo(at, ignoredTypes, &splitIds); + suffixes.reserve(splitIds.size()); + for (const RecordChain &id : splitIds) { + std::string suffix; + llvm::raw_string_ostream os(suffix); + for (StringAttr recordName : id.nameList) { + os << '.' << recordName.getValue(); + } + suffixes.push_back(std::move(suffix)); } - return failure(); } -}; + return suffixes; +} -/// Prepare the module by replacing `llzk.nondet` pod allocation ops with `pod.new`. -static LogicalResult step0(ModuleOp modOp) { - MLIRContext *ctx = modOp.getContext(); - RewritePatternSet patterns {ctx}; - patterns.add(ctx); - ConversionTarget target {*ctx}; +/// Insert a `poly.unifiable_cast` when a rewritten value must match a more specific type. +/// +/// This is the common bridge between wildcard-backed storage values and the more precise types +/// expected by surrounding rewritten IR. The cast is only emitted when the source and target +/// types unify and differ syntactically. +static Value castValueToTypeIfNeeded(OpBuilder &bldr, Location loc, Value value, Type targetType) { + if (value.getType() == targetType) { + return value; + } + assert(typesUnify(value.getType(), targetType) && "expected compatible rewritten types"); + return bldr.create(loc, targetType, value); +} - baseTargetSetup(target); - target.addDynamicallyLegalOp([](NonDetOp op) { return !isa(op.getType()); }); +/// Create a `pod.read` for one record of `podRef`. +inline static ReadPodOp +genRead(OpBuilder &bldr, Location loc, Value podRef, StringAttr recordName) { + Type resultType = + llvm::cast(podRef.getType()).getRecordMap().lookup(recordName.getValue()); + return bldr.create(loc, resultType, podRef, recordName); +} - return applyFullConversion(modOp, target, std::move(patterns)); +/// Create a `pod.write` for one record of `podRef`. +inline static WritePodOp +genWrite(OpBuilder &bldr, Location loc, Value podRef, StringAttr recordName, Value value) { + Type recordType = + llvm::cast(podRef.getType()).getRecordMap().lookup(recordName.getValue()); + return bldr.create( + loc, podRef, recordName, castValueToTypeIfNeeded(bldr, loc, value, recordType) + ); } -/// new member name and type -using MemberInfo = std::pair; -/// original nested pod record name chain -> split scalar member info -using LocalMemberReplacementMap = DenseMap; -/// struct -> original pod-type member name -> LocalMemberReplacementMap -using MemberReplacementMap = DenseMap>; +/// Return the single converted value from a 1:N adaptor range. +inline static Value getSingleConvertedValue(ValueRange values) { + assert(values.size() == 1 && "expected a 1:1 converted value range"); + return values.front(); +} -/// Build a flattened struct-member name like `member_outer_inner_leaf`. -static StringAttr -getFlattenedMemberName(MLIRContext *ctx, StringAttr memberName, ArrayRef recordChain) { - std::string flatName; - llvm::raw_string_ostream os(flatName); - os << memberName.getValue(); - for (StringAttr recordName : recordChain) { - os << '_' << recordName.getValue(); +/// Store the affine-map operand groups needed to rebuild one concrete array instantiation. +/// +/// The layout mirrors `array.new`: `mapOperandStorage` keeps each instantiation group separately, +/// and `numDimsPerMap` records how many values in each group are dimensional arguments. +struct ArrayInstantiationInfo { + SmallVector> mapOperandStorage; + SmallVector numDimsPerMap; +}; + +/// Try to recover affine-map instantiation operands from a concrete array-producing value. +/// +/// This peels compatibility casts, follows simple `pod.read` to dominating `pod.write` +/// forwarding, and succeeds only when the value ultimately traces back to a concrete +/// `array.new` carrying the instantiation groups. +static std::optional tryGetArrayInstantiationInfo(Value value) { + while (auto cast = value.getDefiningOp()) { + value = cast.getInput(); } - return StringAttr::get(ctx, flatName); -} -/// Recursively create scalar leaf members for a POD-typed struct member. -static void flattenPodMemberIntoLeaves( - MemberDefOp originalMember, PodType podTy, SmallVectorImpl &recordChain, - LocalMemberReplacementMap &localRepMapRef, SymbolTable &structSymbolTable, - ConversionPatternRewriter &rewriter -) { - for (RecordAttr record : podTy.getRecords()) { - recordChain.push_back(record.getName()); - if (PodType nestedPodTy = dyn_cast(record.getType())) { - flattenPodMemberIntoLeaves( - originalMember, nestedPodTy, recordChain, localRepMapRef, structSymbolTable, rewriter - ); - recordChain.pop_back(); - continue; + if (ReadPodOp read = value.getDefiningOp()) { + if (WritePodOp write = findNearestForwardableWriteInBlock(read)) { + return tryGetArrayInstantiationInfo(write.getValue()); } + return std::nullopt; + } - StringAttr name = getFlattenedMemberName( - originalMember.getContext(), originalMember.getSymNameAttr(), recordChain - ); - Type ty = record.getType(); - MemberDefOp newMember = rewriter.create( - originalMember.getLoc(), name, ty, originalMember.getSignal(), originalMember.getColumn() - ); - newMember.setPublicAttr(originalMember.hasPublicAttr()); - localRepMapRef[RecordChain(recordChain)] = - std::make_pair(structSymbolTable.insert(newMember), ty); - recordChain.pop_back(); + auto create = value.getDefiningOp(); + if (!create) { + return std::nullopt; + } + + ArrayInstantiationInfo info; + info.mapOperandStorage.reserve(create.getMapOperands().size()); + for (OperandRange group : create.getMapOperands()) { + info.mapOperandStorage.emplace_back(group.begin(), group.end()); + } + + if (DenseI32ArrayAttr numDimsPerMap = create.getNumDimsPerMapAttr()) { + llvm::append_range(info.numDimsPerMap, numDimsPerMap.asArrayRef()); } + + return info; } -/// Split a pod-typed struct member definition into one scalar member definition per POD record. +/// Materialize a scalar array value that preserves the shape of `originalArrTy`. /// -/// The replacement map records the fresh member symbols so later rewrites can retarget -/// `struct.readm` and `struct.writem` operations to the split members. -class SplitPodInMemberDefOp : public OpConversionPattern { - SymbolTableCollection &tables; - MemberReplacementMap &repMapRef; +/// This is used as a shape-only carrier for `array.len` when an array-of-POD splits to +/// zero parallel leaf arrays (for example, `!array.type<... x !pod.type<[]>>`). +static Value materializeArrayLengthCarrier( + Value originalArrRef, ArrayType originalArrTy, Location loc, ConversionPatternRewriter &rewriter +) { + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(originalArrTy); -public: - SplitPodInMemberDefOp( - MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap - ) - : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap) {} + if (auto create = originalArrRef.getDefiningOp()) { + if (create.getMapOperands().empty()) { + return rewriter.create(loc, carrierTy); + } - inline static bool legal(MemberDefOp op) { return !splittablePod(op.getType()); } + SmallVector mapOperands; + mapOperands.reserve(create.getMapOperands().size()); + for (OperandRange mapOperandGroup : create.getMapOperands()) { + mapOperands.push_back(mapOperandGroup); + } + return rewriter.create( + loc, carrierTy, mapOperands, create.getNumDimsPerMapAttr() + ); + } - LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); } + if (std::optional instantiation = + tryGetArrayInstantiationInfo(originalArrRef)) { + if (instantiation->mapOperandStorage.empty()) { + return rewriter.create(loc, carrierTy); + } - void - rewrite(MemberDefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - StructDefOp inStruct = op->getParentOfType(); - assert(inStruct); - LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()]; + SmallVector mapOperands; + mapOperands.reserve(instantiation->mapOperandStorage.size()); + for (const SmallVector &group : instantiation->mapOperandStorage) { + mapOperands.push_back(group); + } + return rewriter.create( + loc, carrierTy, mapOperands, ArrayRef(instantiation->numDimsPerMap) + ); + } - PodType podTy = llvm::cast(adaptor.getType()); // safe per legal() check + bool hasAffineDims = llvm::any_of(originalArrTy.getDimensionSizes(), [](Attribute dimSize) { + return llvm::isa(dimSize); + }); + if (!hasAffineDims) { + return rewriter.create(loc, carrierTy); + } - SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct); + return rewriter.create(loc, carrierTy); +} + +/// Flatten a range of converted value ranges into a single list of values. +template +static SmallVector flattenConvertedValues(RangeOfRanges ranges) { + SmallVector values; + for (ValueRange range : ranges) { + llvm::append_range(values, range); + } + return values; +} + +/// Return `true` iff the inputs are the same size and each value type in `values` unifies +/// with the corresponding `types` entry. +template +inline static bool allValueTypesUnifyWithTypes(const ValueRangeLike &values, ArrayRef types) { + return llvm::all_of_zip(values, types, [](auto value, Type type) { + return typesUnify(value.getType(), type); + }); +} + +/// Replace any AffineMap-backed array dimensions nested within `type` with wildcard `?` dims. +/// +/// This preserves the overall array nesting while erasing only the affine-map dimensions that +/// cannot always be witnessed after flattening a POD leaf array into a split array value. +static Type replaceAffineMapArrayDimsWithWildcards(Type type) { + auto arrTy = llvm::dyn_cast(type); + if (!arrTy) { + return type; + } + + Builder builder(arrTy.getContext()); + SmallVector dims; + dims.reserve(arrTy.getDimensionSizes().size()); + for (Attribute dimSize : arrTy.getDimensionSizes()) { + if (llvm::isa(dimSize)) { + dims.push_back(builder.getIndexAttr(ShapedType::kDynamic)); + } else { + dims.push_back(dimSize); + } + } + + return arrTy.cloneWith(replaceAffineMapArrayDimsWithWildcards(arrTy.getElementType()), dims); +} + +/// Return the wildcard-backed storage split type for one flattened POD leaf. +/// +/// The precise split type preserves the original affine maps in the flattened leaf array. The +/// storage split type uses the same outer shape but replaces hidden leaf-array affine dims with +/// `?` until a matching instantiation can be recovered from concrete leaf-array values. +static ArrayType getSplitPodArrayStorageType(ArrayType arrTy, ArrayRef recordChain) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + Type leafType = getFlattenedTypeAlongPath(elemPodTy, recordChain); + return flattenArrayElementType(arrTy, replaceAffineMapArrayDimsWithWildcards(leafType)); +} + +/// Create an array value that callers can fully initialize via explicit writes or inserts. +/// +/// Use `llzk.nondet` as the base when affine-map dimensions are present because `array.new` +/// cannot carry both inline elements and affine-map instantiation operands. +inline static Value createWritableArrayValue(OpBuilder &bldr, Location loc, ArrayType arrTy) { + if (hasAffineMapAttr(arrTy)) { + return bldr.create(loc, arrTy); + } else { + return bldr.create(loc, arrTy); + } +} + +/// Return `true` iff two recovered array instantiations can be rebuilt identically. +static bool equivalentArrayInstantiationInfo( + const ArrayInstantiationInfo &lhs, const ArrayInstantiationInfo &rhs +) { + if (lhs.numDimsPerMap != rhs.numDimsPerMap || + lhs.mapOperandStorage.size() != rhs.mapOperandStorage.size()) { + return false; + } + + for (auto [lhsGroup, rhsGroup] : llvm::zip_equal(lhs.mapOperandStorage, rhs.mapOperandStorage)) { + if (lhsGroup != rhsGroup) { + return false; + } + } + + return true; +} + +/// Describe whether a set of leaf arrays shares one recoverable instantiation. +enum class CommonArrayInstantiationStatus : std::uint8_t { + unavailable, + inferred, + conflict, +}; + +/// Recover a single shared affine-map instantiation from all of `values`, if one exists. +/// +/// Returns `inferred` when every value resolves to the same concrete `array.new` +/// instantiation, `unavailable` when any value has no recoverable witness, and `conflict` +/// when the recovered instantiations disagree. +static CommonArrayInstantiationStatus +inferCommonArrayInstantiation(ArrayRef values, ArrayInstantiationInfo &result) { + bool initialized = false; + for (Value value : values) { + std::optional info = tryGetArrayInstantiationInfo(value); + if (!info) { + return CommonArrayInstantiationStatus::unavailable; + } + + if (!initialized) { + result = std::move(*info); + initialized = true; + continue; + } + + if (!equivalentArrayInstantiationInfo(result, *info)) { + return CommonArrayInstantiationStatus::conflict; + } + } + + return initialized ? CommonArrayInstantiationStatus::inferred + : CommonArrayInstantiationStatus::unavailable; +} + +/// Generate `arith.constant` indices for one static array element position. +static SmallVector genArrayIndexConstants(OpBuilder &bldr, Location loc, ArrayAttr index) { + SmallVector indices; + for (Attribute attr : index) { + assert(llvm::isa(attr) && "array index must be an integer attribute"); + indices.push_back(bldr.create(loc, llvm::cast(attr))); + } + return indices; +} + +/// Return the type produced by selecting `numIndices` leading dimensions from `arrTy`. +static Type getArraySelectionType(ArrayType arrTy, size_t numIndices) { + assert(numIndices <= arrTy.getDimensionSizes().size() && "cannot select past the array rank"); + if (numIndices == arrTy.getDimensionSizes().size()) { + return arrTy.getElementType(); + } + return ArrayType::get(arrTy.getElementType(), arrTy.getDimensionSizes().drop_front(numIndices)); +} + +/// Create an `array.read` or `array.extract` for one concrete element or subarray. +static Value genArrayRead(OpBuilder &bldr, Location loc, Value arrayRef, ArrayRef indices) { + Type t = arrayRef.getType(); + assert(llvm::isa(t) && "array access must target an array type"); + ArrayType arrTy = llvm::cast(t); + if (indices.size() == arrTy.getDimensionSizes().size()) { + return bldr.create(loc, arrTy.getElementType(), arrayRef, indices); + } + return bldr.create( + loc, llvm::cast(getArraySelectionType(arrTy, indices.size())), arrayRef, indices + ); +} + +inline static Value genArrayRead(OpBuilder &bldr, Location loc, Value arrayRef, ArrayAttr index) { + SmallVector indices = genArrayIndexConstants(bldr, loc, index); + return genArrayRead(bldr, loc, arrayRef, indices); +} + +/// Create an `array.write` or `array.insert` for one concrete element or subarray. +static void +genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayRef indices, Value value) { + Type t = arrayRef.getType(); + assert(llvm::isa(t) && "array access must target an array type"); + ArrayType arrTy = llvm::cast(t); + value = castValueToTypeIfNeeded(bldr, loc, value, getArraySelectionType(arrTy, indices.size())); + if (indices.size() == arrTy.getDimensionSizes().size()) { + bldr.create(loc, arrayRef, indices, value); + return; + } + assert(llvm::isa(value.getType()) && "subarray insertion requires an array value"); + bldr.create(loc, arrayRef, indices, value); +} + +inline static void +genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayAttr index, Value value) { + SmallVector indices = genArrayIndexConstants(bldr, loc, index); + genArrayWrite(bldr, loc, arrayRef, indices, value); +} + +/// Strip compatibility casts introduced while threading POD-derived array values through rewrites. +static Value peelUnifiableCasts(Value value) { + while (auto cast = value.getDefiningOp()) { + value = cast.getInput(); + } + return value; +} + +/// Collect split leaf arrays that are already available for an aggregate array-of-POD value. +/// +/// This peels compatibility casts and forwards through a dominating same-record `pod.write` so +/// nested POD scalarization can reuse the split-array representation already produced elsewhere in +/// the pass instead of re-materializing dynamic arrays element-by-element. +static bool tryCollectDirectSplitPodArrayLeafValues( + Value arrayValue, ArrayType arrTy, ArrayRef splitTypes, SmallVectorImpl &leafArrays +) { + arrayValue = peelUnifiableCasts(arrayValue); + + if (auto cast = arrayValue.getDefiningOp()) { + if (cast->getNumResults() != 1 || cast.getResult(0).getType() != arrTy || + cast->getNumOperands() != splitTypes.size()) { + return false; + } + + leafArrays.reserve(splitTypes.size()); + for (auto [operand, splitType] : llvm::zip_equal(cast.getOperands(), splitTypes)) { + if (operand.getType() != splitType) { + return false; + } + leafArrays.push_back(operand); + } + return true; + } + + if (ReadPodOp readOp = arrayValue.getDefiningOp()) { + if (WritePodOp writeOp = findNearestForwardableWriteInBlock(readOp)) { + return tryCollectDirectSplitPodArrayLeafValues( + writeOp.getValue(), arrTy, splitTypes, leafArrays + ); + } + } + + return false; +} + +/// Return whether `op` is preceded in its block by a write to `podRef.recordName`. +static bool hasEarlierWriteToRecordInBlock(Operation *op, Value podRef, StringAttr recordName) { + for (Operation &candidate : *op->getBlock()) { + if (&candidate == op) { + return false; + } + if (auto writeOp = dyn_cast(&candidate)) { + if (isSamePodRecord(writeOp, podRef, recordName)) { + return true; + } + } else if (hasNestedWriteToRecord(candidate, podRef, recordName)) { + return true; + } + } + return false; +} + +/// Return whether the read is preceded by a write to the same pod record within its block. +static bool hasEarlierWriteInBlock(ReadPodOp readOp) { + return hasEarlierWriteToRecordInBlock( + readOp.getOperation(), readOp.getPodRef(), readOp.getRecordNameAttr() + ); +} + +/// Return whether `op` is preceded in its block by any write to `podRef`. +static bool hasEarlierWriteToPodInBlock(Operation *op, Value podRef) { + for (Operation &candidate : *op->getBlock()) { + if (&candidate == op) { + return false; + } + if (auto writeOp = dyn_cast(&candidate)) { + if (writeOp.getPodRef() == podRef) { + return true; + } + } else if (hasNestedWriteToPod(candidate, podRef)) { + return true; + } + } + return false; +} + +/// Return `true` iff `readOp` names a fresh pod record that has not been initialized or written. +static bool isFreshUnwrittenPodRead(ReadPodOp readOp) { + NewPodOp newPod = readOp.getPodRef().getDefiningOp(); + if (!newPod) { + return false; + } + auto isReadOpRecordName = [&readOp](Attribute attr) { + return attr == readOp.getRecordNameAttr(); + }; + return llvm::none_of(newPod.getInitializedRecords(), isReadOpRecordName) && + !hasEarlierWriteInBlock(readOp); +} + +/// Return `true` iff `value` is an unwritten array-of-POD field read from a fresh `pod.new`. +static bool isFreshUnwrittenPodArrayRead(Value value) { + value = peelUnifiableCasts(value); + ReadPodOp readOp = value.getDefiningOp(); + return readOp && splittablePodArray(readOp.getType()) && isFreshUnwrittenPodRead(readOp); +} + +/// Read one flattened POD leaf, including leaves that live inside an array-of-POD record. +static Value +genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef recordChain) { + if (recordChain.empty()) { + return value; + } + + Type valueType = value.getType(); + if (llvm::isa(valueType)) { + Value nextValue = genRead(bldr, loc, value, recordChain.front()); + return genReadAlongPath(bldr, loc, nextValue, recordChain.drop_front()); + } + + if (ArrayType arrTy = splittablePodArray(valueType)) { + auto splitArrTy = llvm::cast(getFlattenedTypeAlongPath(valueType, recordChain)); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + auto *splitIt = llvm::find(splitIds, RecordChain(recordChain)); + assert(splitIt != splitIds.end() && "record path must name a flattened POD array leaf"); + size_t splitIdx = std::distance(splitIds.begin(), splitIt); + + SmallVector leafArrays; + if (tryCollectDirectSplitPodArrayLeafValues(value, arrTy, splitTypes, leafArrays)) { + return leafArrays[splitIdx]; + } + + Value strippedValue = peelUnifiableCasts(value); + if (strippedValue.getDefiningOp()) { + auto splitLeafReads = + bldr.create(loc, TypeRange(splitTypes), strippedValue); + return splitLeafReads.getResult(splitIdx); + } + + if (isFreshUnwrittenPodArrayRead(value)) { + return createWritableArrayValue(bldr, loc, splitArrTy); + } + + if (!arrTy.hasStaticShape()) { + llvm_unreachable( + "non-static nested array-of-POD scalarization requires split-array backing or an " + "uninitialized pod field" + ); + } + + auto subIndices = arrTy.getSubelementIndices(); + assert(subIndices && "static-shape arrays must provide subelement indices"); + + Value splitArray = createWritableArrayValue(bldr, loc, splitArrTy); + for (ArrayAttr index : *subIndices) { + Value element = genArrayRead(bldr, loc, value, index); + Value leafValue = genReadAlongPath(bldr, loc, element, recordChain); + genArrayWrite(bldr, loc, splitArray, index, leafValue); + } + return splitArray; + } + + llvm_unreachable("record path cannot continue through a non-POD leaf"); +} + +/// Read a flattened POD leaf by following each record name in `recordChain`. +inline static Value +genReadAlongPath(OpBuilder &bldr, Location loc, Value podRef, const RecordChain &recordChain) { + return genReadAlongPath(bldr, loc, podRef, ArrayRef(recordChain.nameList)); +} + +/// Reconstruct a POD record from the leaf values collected while splitting nested accesses. +static Value rebuildFlattenedPodRecord( + OpBuilder &bldr, Location loc, Type recordType, SmallVectorImpl &recordChain, + const DenseMap &leafValues +) { + if (PodType nestedPodTy = dyn_cast(recordType)) { + NewPodOp nestedPod = bldr.create(loc, nestedPodTy); + for (RecordAttr record : nestedPodTy.getRecords()) { + recordChain.push_back(record.getName()); + Value recordValue = + rebuildFlattenedPodRecord(bldr, loc, record.getType(), recordChain, leafValues); + genWrite(bldr, loc, nestedPod, record.getName(), recordValue); + recordChain.pop_back(); + } + return nestedPod; + } + + if (ArrayType arrTy = splittablePodArray(recordType)) { + if (!arrTy.hasStaticShape()) { + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector leafArrays; + leafArrays.reserve(splitIds.size()); + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + SmallVector fullChain(recordChain.begin(), recordChain.end()); + llvm::append_range(fullChain, id.nameList); + auto it = leafValues.find(RecordChain(fullChain)); + assert(it != leafValues.end() && "missing flattened POD array leaf value"); + leafArrays.push_back(castValueToTypeIfNeeded(bldr, loc, it->second, splitType)); + } + + return bldr.create(loc, TypeRange {arrTy}, leafArrays) + .getResult(0); + } + + auto elemPodTy = llvm::cast(arrTy.getElementType()); + auto subIndices = arrTy.getSubelementIndices(); + assert(subIndices && "static-shape arrays must provide subelement indices"); + + Value rebuiltArray = bldr.create(loc, arrTy); + for (ArrayAttr index : *subIndices) { + DenseMap elementLeafValues; + SmallVector elementRecordChain; + forEachPodLeaf(elemPodTy, elementRecordChain, [&](RecordChain id, Type) { + SmallVector fullChain(recordChain.begin(), recordChain.end()); + llvm::append_range(fullChain, id.nameList); + auto it = leafValues.find(RecordChain(fullChain)); + assert(it != leafValues.end() && "missing flattened POD array leaf value"); + elementLeafValues[id] = genArrayRead(bldr, loc, it->second, index); + }); + + NewPodOp elementPod = bldr.create(loc, elemPodTy); + SmallVector nestedChain; + for (RecordAttr record : elemPodTy.getRecords()) { + nestedChain.push_back(record.getName()); + Value recordValue = + rebuildFlattenedPodRecord(bldr, loc, record.getType(), nestedChain, elementLeafValues); + genWrite(bldr, loc, elementPod, record.getName(), recordValue); + nestedChain.pop_back(); + } + genArrayWrite(bldr, loc, rebuiltArray, index, elementPod); + } + return rebuiltArray; + } + + auto it = leafValues.find(RecordChain(recordChain)); + assert(it != leafValues.end() && "missing flattened POD leaf value"); + return it->second; +} + +using VirtualPodLeafMap = DenseMap; +using VirtualPodValueMap = DenseMap; +using DeferredPodArrayLeafMap = DenseMap>; + +/// Return the flattened leaf values for `podValue` when it is tracked as a virtual POD. +static const VirtualPodLeafMap * +lookupVirtualPodLeafMap(Value podValue, const VirtualPodValueMap &virtualPods) { + auto it = virtualPods.find(podValue); + return it != virtualPods.end() ? &it->second : nullptr; +} + +/// Collect flattened POD leaf values in canonical traversal order. +static SmallVector orderedVirtualPodLeafValues( + PodType podTy, Location loc, OpBuilder &bldr, const VirtualPodLeafMap &leafValues +) { + SmallVector orderedValues; + SmallVector recordChain; + forEachPodLeaf( + podTy, recordChain, + [&leafValues, &orderedValues, &bldr, loc](const RecordChain &id, Type leafType) { + auto it = leafValues.find(id); + assert(it != leafValues.end() && "missing virtual POD leaf value"); + orderedValues.push_back(castValueToTypeIfNeeded(bldr, loc, it->second, leafType)); + } + ); + return orderedValues; +} + +/// Create a POD-typed placeholder for virtual leaf storage tracked in `leafValues`. +/// +/// PODs that embed affine-map-parameterized arrays cannot always be represented by a bare +/// `pod.new` at this stage because there may be no op-local instantiation operands available. +/// Use an unrealized cast from the ordered leaf values for those cases; later rewrites consult +/// `virtualPods` directly, and only concrete `pod.new` placeholders require materialization. +static Value createVirtualPodPlaceholder( + OpBuilder &bldr, Location loc, PodType podTy, const VirtualPodLeafMap &leafValues +) { + if (!hasAffineMapAttr(podTy)) { + return bldr.create(loc, podTy); + } + + SmallVector orderedValues = orderedVirtualPodLeafValues(podTy, loc, bldr, leafValues); + return bldr.create(loc, TypeRange {podTy}, orderedValues) + .getResult(0); +} + +/// Materialize the tracked contents of a virtual POD into concrete `pod.write` operations. +inline static void +materializeVirtualPod(OpBuilder &bldr, NewPodOp pod, const VirtualPodLeafMap &leafValues) { + Location loc = pod.getLoc(); + PodType podTy = pod.getType(); + SmallVector recordChain; + for (RecordAttr record : podTy.getRecords()) { + recordChain.push_back(record.getName()); + Value recordValue = + rebuildFlattenedPodRecord(bldr, loc, record.getType(), recordChain, leafValues); + genWrite(bldr, loc, pod, record.getName(), recordValue); + recordChain.pop_back(); + } +} + +/// Return the latest same-block operation that defines one of `leafValues`, or `pod` itself. +/// +/// Virtual PODs created from split block arguments can be updated later with scalar values defined +/// after the placeholder. Replaying the deferred writes immediately after the placeholder can then +/// violate SSA dominance. Materializing after the latest same-block leaf definition preserves the +/// original write ordering for these straight-line updates while keeping the fallback local. +static Operation * +findVirtualPodMaterializationAnchor(NewPodOp pod, const VirtualPodLeafMap &leafValues) { + Operation *anchor = pod.getOperation(); + Block *block = anchor->getBlock(); + + for (const auto &it : leafValues) { + Operation *defOp = it.second.getDefiningOp(); + if (!defOp || defOp->getBlock() != block) { + continue; + } + if (anchor->isBeforeInBlock(defOp)) { + anchor = defOp; + } + } + + return anchor; +} + +/// Return `true` iff a read from a virtual POD can be resolved without materializing it. +static bool canResolveVirtualPodRead(ReadPodOp op, const VirtualPodValueMap &virtualPods) { + if (!lookupVirtualPodLeafMap(op.getPodRef(), virtualPods) || hasEarlierWriteInBlock(op) || + findNearestForwardableWriteInBlock(op)) { + return false; + } + Type recType = llvm::cast(op.getPodRefType()).getRecordMap().lookup(op.getRecordName()); + return llvm::isa(recType) || !splittablePodArray(recType); +} + +/// Return `true` iff step 2 should defer splitting this array read until POD-aware rewriting. +static bool shouldDeferPodArrayReadToStep3(ReadArrayOp op) { + return splittablePodArray(op.getArrRefType()) && + llvm::isa_and_present(op.getArrRef().getDefiningOp()); +} + +/// Return the suffixes to append to a function arg/result name when splitting the given type. +static SmallVector getSplitRecordNameSuffixes(Type type) { + SmallVector suffixes; + if (PodType pt = splittablePod(type)) { SmallVector recordChain; - flattenPodMemberIntoLeaves(op, podTy, recordChain, localRepMapRef, structSymbolTable, rewriter); + forEachPodLeaf(pt, recordChain, [&suffixes](const RecordChain &id, Type) { + std::string suffix; + llvm::raw_string_ostream os(suffix); + for (StringAttr recordName : id.nameList) { + os << '.' << recordName.getValue(); + } + suffixes.push_back(std::move(suffix)); + }); + } + return suffixes; +} + +// If the operand has PodType, add reads from all pod records to the `newOperands` list otherwise +// add the original operand to the list. +static void processInputOperand( + Location loc, Value operand, SmallVector &newOperands, + ConversionPatternRewriter &rewriter, Operation *userOp = nullptr, + const VirtualPodValueMap *virtualPods = nullptr +) { + if (PodType pt = splittablePod(operand.getType())) { + if (virtualPods) { + if (const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(operand, *virtualPods); + leafValues && (!userOp || !hasEarlierWriteToPodInBlock(userOp, operand))) { + llvm::append_range( + newOperands, orderedVirtualPodLeafValues(pt, loc, rewriter, *leafValues) + ); + return; + } + } + SmallVector recordChain; + forEachPodLeaf(pt, recordChain, [&](const RecordChain &id, Type) { + newOperands.push_back(genReadAlongPath(rewriter, loc, operand, id)); + }); + } else { + newOperands.push_back(operand); + } +} + +/// For each operand with PodType, add reads from all pod records in place of the original operand +/// and update the op to use the new operands. +static void processInputOperands( + ValueRange operands, MutableOperandRange outputOpRef, Operation *op, + ConversionPatternRewriter &rewriter, const VirtualPodValueMap *virtualPods = nullptr +) { + SmallVector newOperands; + for (Value v : operands) { + processInputOperand(op->getLoc(), v, newOperands, rewriter, op, virtualPods); + } + rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() { + outputOpRef.assign(ValueRange(newOperands)); + }); +} + +/// Update the tracked leaf values for one top-level POD record after a virtual `pod.write`. +static void updateVirtualPodRecordLeafValues( + Location loc, StringAttr recordName, Type recordType, Value recordValue, + const VirtualPodValueMap &virtualPods, ConversionPatternRewriter &rewriter, + VirtualPodLeafMap &leafValues +) { + SmallVector prefix {recordName}; + + if (PodType nestedPodTy = llvm::dyn_cast(recordType)) { + if (const VirtualPodLeafMap *nestedLeafValues = + lookupVirtualPodLeafMap(recordValue, virtualPods)) { + SmallVector nestedRecordChain; + forEachPodLeaf(nestedPodTy, nestedRecordChain, [&](const RecordChain &id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + leafValues[RecordChain(fullChain)] = nestedLeafValues->at(id); + }); + return; + } + + SmallVector nestedRecordChain; + forEachPodLeaf(nestedPodTy, nestedRecordChain, [&](const RecordChain &id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + leafValues[RecordChain(fullChain)] = genReadAlongPath(rewriter, loc, recordValue, id); + }); + return; + } + + if (ArrayType arrTy = splittablePodArray(recordType)) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + SmallVector nestedRecordChain; + forEachPodLeaf(elemPodTy, nestedRecordChain, [&](const RecordChain &id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + leafValues[RecordChain(fullChain)] = genReadAlongPath(rewriter, loc, recordValue, id); + }); + return; + } + + leafValues[RecordChain(prefix)] = castValueToTypeIfNeeded(rewriter, loc, recordValue, recordType); +} + +/// Register the dialects and operations that remain legal across the conversion-based stages. +inline static void baseTargetSetup(ConversionTarget &target) { + target.addLegalDialect< + LLZKDialect, array::ArrayDialect, boolean::BoolDialect, cast::CastDialect, + constrain::ConstrainDialect, component::StructDialect, felt::FeltDialect, + function::FunctionDialect, global::GlobalDialect, include::IncludeDialect, pod::PODDialect, + polymorphic::PolymorphicDialect, ram::RAMDialect, string::StringDialect, arith::ArithDialect, + scf::SCFDialect>(); + target.addLegalOp(); +} + +/// Rewrite pod-typed `llzk.nondet` allocations into explicit `pod.new` allocations so the rest of +/// the pass only needs to reason about POD storage through POD dialect operations. +class NondetToNewPod : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + NonDetOp nondetOp, OpAdaptor, ConversionPatternRewriter &rewriter + ) const override { + if (auto pt = dyn_cast(nondetOp.getType())) { + rewriter.replaceOpWithNewOp(nondetOp, pt); + return success(); + } + return failure(); + } +}; + +/// Prepare the module by replacing `llzk.nondet` pod allocation ops with `pod.new`. +static LogicalResult step0(ModuleOp modOp) { + MLIRContext *ctx = modOp.getContext(); + RewritePatternSet patterns {ctx}; + patterns.add(ctx); + ConversionTarget target {*ctx}; + + baseTargetSetup(target); + target.addDynamicallyLegalOp([](NonDetOp op) { return !isa(op.getType()); }); + + return applyFullConversion(modOp, target, std::move(patterns)); +} + +/// new member name and type +using MemberInfo = std::pair; +/// original nested pod record name chain -> split scalar member info +using LocalMemberReplacementMap = DenseMap; +/// struct -> original pod-type member name -> LocalMemberReplacementMap +using MemberReplacementMap = DenseMap>; + +/// Build a flattened struct-member name like `member_outer_inner_leaf`. +static StringAttr +getFlattenedMemberName(MLIRContext *ctx, StringAttr memberName, ArrayRef recordChain) { + std::string flatName; + llvm::raw_string_ostream os(flatName); + os << memberName.getValue(); + for (StringAttr recordName : recordChain) { + os << '_' << recordName.getValue(); + } + return StringAttr::get(ctx, flatName); +} + +/// Recursively create scalar leaf members for a POD-typed struct member. +static void flattenPodMemberIntoLeaves( + MemberDefOp originalMember, PodType podTy, SmallVectorImpl &recordChain, + LocalMemberReplacementMap &localRepMapRef, SymbolTable &structSymbolTable, + ConversionPatternRewriter &rewriter +) { + forEachPodLeaf(podTy, recordChain, [&](const RecordChain &id, Type ty) { + StringAttr name = getFlattenedMemberName( + originalMember.getContext(), originalMember.getSymNameAttr(), id.nameList + ); + MemberDefOp newMember = rewriter.create( + originalMember.getLoc(), name, ty, originalMember.getSignal(), originalMember.getColumn() + ); + newMember.setPublicAttr(originalMember.hasPublicAttr()); + localRepMapRef[id] = std::make_pair(structSymbolTable.insert(newMember), ty); + }); +} + +/// Split a pod-typed struct member definition into one scalar member definition per POD record. +/// +/// The replacement map records the fresh member symbols so later rewrites can retarget +/// `struct.readm` and `struct.writem` operations to the split members. +class SplitPodInMemberDefOp : public OpConversionPattern { + SymbolTableCollection &tables; + MemberReplacementMap &repMapRef; + +public: + SplitPodInMemberDefOp( + MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap + ) + : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap) {} + + inline static bool legal(MemberDefOp op) { return !splittablePod(op.getType()); } + + LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); } + + void + rewrite(MemberDefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + StructDefOp inStruct = op->getParentOfType(); + assert(inStruct); + LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()]; + + PodType podTy = llvm::cast(adaptor.getType()); // safe per legal() check + + SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct); + SmallVector recordChain; + flattenPodMemberIntoLeaves(op, podTy, recordChain, localRepMapRef, structSymbolTable, rewriter); + rewriter.eraseOp(op); + } +}; + +/// Split an array-of-POD struct member definition into one parallel array member per POD leaf. +class SplitPodArrayInMemberDefOp : public OpConversionPattern { + SymbolTableCollection &tables; + MemberReplacementMap &repMapRef; + +public: + SplitPodArrayInMemberDefOp( + MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap + ) + : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap) {} + + inline static bool legal(MemberDefOp op) { return !splittablePodArray(op.getType()); } + + LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); } + + void + rewrite(MemberDefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + StructDefOp inStruct = op->getParentOfType(); + assert(inStruct); + LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()]; + + ArrayType arrTy = llvm::cast(adaptor.getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(arrTy); + rewriter.modifyOpInPlace(op, [&]() { op.setType(carrierTy); }); + localRepMapRef[RecordChain()] = std::make_pair(op.getSymNameAttr(), carrierTy); + return; + } + + SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct); + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + StringAttr name = getFlattenedMemberName(op.getContext(), op.getSymNameAttr(), id.nameList); + MemberDefOp newMember = rewriter.create( + op.getLoc(), name, splitType, op.getSignal(), op.getColumn() + ); + newMember.setPublicAttr(op.hasPublicAttr()); + localRepMapRef[id] = std::make_pair(structSymbolTable.insert(newMember), splitType); + } + rewriter.eraseOp(op); + } +}; + +/// Replace direct `PodType` struct members with scalar members and arrays-of-POD with parallel +/// array members named after the corresponding POD leaf. +static LogicalResult +step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) { + MLIRContext *ctx = modOp.getContext(); + + RewritePatternSet patterns(ctx); + + patterns.add(ctx, symTables, memberRepMap); + + ConversionTarget target(*ctx); + baseTargetSetup(target); + target.addDynamicallyLegalOp([](MemberDefOp op) { + return SplitPodInMemberDefOp::legal(op) && SplitPodArrayInMemberDefOp::legal(op); + }); + + LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split pod-type and array-of-pod members\n";); + return applyFullConversion(modOp, target, std::move(patterns)); +} + +/// Type converter that replaces each array-of-POD type with one parallel array type per POD leaf. +/// +/// Besides splitting result types, this also materializes compatibility casts between precise +/// split array types and wildcard-backed storage split types when target or block-argument +/// conversion needs to cross that boundary. +class PodArrayTypeConverter : public TypeConverter { +public: + PodArrayTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion( + [](ArrayType arrTy, SmallVectorImpl &results) -> std::optional { + if (!splittablePodArray(arrTy)) { + return std::nullopt; + } + convertPodArrayTypeTo(arrTy, results); + return success(); + } + ); + + auto materializeCast = [](OpBuilder &bldr, Type targetType, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1 || !typesUnify(inputs.front().getType(), targetType)) { + return {}; + } + return castValueToTypeIfNeeded(bldr, loc, inputs.front(), targetType); + }; + addTargetMaterialization(materializeCast); + addArgumentMaterialization(materializeCast); + addSourceMaterialization(materializeCast); + } +}; + +/// Split `llzk.nondet` of array-of-POD type into one `llzk.nondet` per parallel leaf array. +class SplitPodArrayNonDetOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(NonDetOp op) { return !splittablePodArray(op.getType()); } + + LogicalResult match(NonDetOp op) const override { return failure(legal(op)); } + + void rewrite(NonDetOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector splitTypes; + splitPodArrayTypeTo(op.getType(), splitTypes); + if (splitTypes.empty()) { + rewriter.replaceOpWithNewOp( + op, getZeroLeafPodArrayShapeCarrierType(llvm::cast(op.getType())) + ); + return; + } + SmallVector replacements; + replacements.reserve(splitTypes.size()); + for (Type splitType : splitTypes) { + replacements.push_back(rewriter.create(op.getLoc(), splitType)); + } + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + } +}; + +/// Split `array.new` of array-of-POD type into one `array.new` per parallel leaf array. +/// +/// For each leaf, the precise split type preserves the original affine maps in the flattened leaf +/// array. When hidden leaf-array affine dims have no direct witness, the rewrite may first build a +/// wildcard-backed storage split type with the same outer shape and cast back to the precise type. +/// +/// Uninitialized `array.new` uses that storage fallback directly when needed. Explicit-element +/// `array.new` tries to infer one shared affine-map instantiation from all leaf arrays so it can +/// materialize the precise split type immediately. If different elements imply conflicting +/// instantiations, the rewrite remains a hard failure. +class SplitPodArrayCreateArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(CreateArrayOp op) { return !splittablePodArray(op.getType()); } + + LogicalResult matchAndRewrite( + CreateArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + ArrayType arrTy = llvm::cast(op.getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(arrTy); + if (adaptor.getMapOperands().empty()) { + rewriter.replaceOpWithNewOp(op, carrierTy); + return success(); + } + + SmallVector> mapOperandStorage; + SmallVector mapOperands; + mapOperandStorage.reserve(adaptor.getMapOperands().size()); + mapOperands.reserve(adaptor.getMapOperands().size()); + for (ArrayRef mapOperandGroup : adaptor.getMapOperands()) { + mapOperandStorage.push_back(flattenConvertedValues(mapOperandGroup)); + } + for (const SmallVector &values : mapOperandStorage) { + mapOperands.push_back(values); + } + rewriter.replaceOpWithNewOp( + op, carrierTy, mapOperands, op.getNumDimsPerMapAttr() + ); + return success(); + } + + SmallVector replacements; + replacements.reserve(splitTypes.size()); + DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr(); + if (isNullOrEmpty(numDimsPerMap)) { + if (adaptor.getElements().empty()) { + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + ArrayType preciseSplitType = llvm::cast(splitType); + ArrayType storageSplitType = getSplitPodArrayStorageType(arrTy, id.nameList); + Value splitArray = rewriter.create(op.getLoc(), storageSplitType); + replacements.push_back( + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitArray, preciseSplitType) + ); + } + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } + + auto elementIndices = arrTy.getSubelementIndices(); + assert(elementIndices && "array.new with explicit elements requires a static array shape"); + assert( + elementIndices->size() == adaptor.getElements().size() && + "array.new element count must match the outer array cardinality" + ); + + // Inline initializers are linearized only across the original outer array dimensions. When + // a flattened POD leaf is itself an array, populate the rewritten split array one outer + // element at a time so each leaf array becomes a subarray insert rather than a malformed + // inline operand to the flattened `array.new`. + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + ArrayType preciseSplitType = llvm::cast(splitType); + ArrayType storageSplitType = getSplitPodArrayStorageType(arrTy, id.nameList); + + SmallVector leafValues; + leafValues.reserve(adaptor.getElements().size()); + for (ValueRange elementRange : adaptor.getElements()) { + Value element = getSingleConvertedValue(elementRange); + leafValues.push_back(genReadAlongPath(rewriter, op.getLoc(), element, id)); + } + + ArrayType materializedType = storageSplitType; + Value splitArray; + if (storageSplitType != preciseSplitType) { + ArrayInstantiationInfo instantiationInfo; + switch (inferCommonArrayInstantiation(leafValues, instantiationInfo)) { + case CommonArrayInstantiationStatus::conflict: + // TODO: this POD could be promoted to a complete `struct.def` but that's not easy. + op.emitOpError( + "with POD elements having conflicting affine map instantiations cannot be promoted " + "to higher dimensional array" + ); + return failure(); + case CommonArrayInstantiationStatus::inferred: { + materializedType = preciseSplitType; + SmallVector mapOperands; + mapOperands.reserve(instantiationInfo.mapOperandStorage.size()); + for (const SmallVector &values : instantiationInfo.mapOperandStorage) { + mapOperands.push_back(values); + } + splitArray = rewriter.create( + op.getLoc(), materializedType, mapOperands, instantiationInfo.numDimsPerMap + ); + break; + } + case CommonArrayInstantiationStatus::unavailable: + break; + } + } + + if (!splitArray) { + splitArray = createWritableArrayValue(rewriter, op.getLoc(), materializedType); + } + + for (auto [index, leafValue] : llvm::zip_equal(*elementIndices, leafValues)) { + genArrayWrite(rewriter, op.getLoc(), splitArray, index, leafValue); + } + replacements.push_back( + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitArray, preciseSplitType) + ); + } + } else { + SmallVector> mapOperandStorage; + SmallVector mapOperands; + mapOperandStorage.reserve(adaptor.getMapOperands().size()); + mapOperands.reserve(adaptor.getMapOperands().size()); + for (ArrayRef mapOperandGroup : adaptor.getMapOperands()) { + mapOperandStorage.push_back(flattenConvertedValues(mapOperandGroup)); + } + for (const SmallVector &values : mapOperandStorage) { + mapOperands.push_back(values); + } + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + ArrayType preciseSplitType = llvm::cast(splitType); + ArrayType storageSplitType = getSplitPodArrayStorageType(arrTy, id.nameList); + Value splitArray = rewriter.create( + op.getLoc(), storageSplitType, mapOperands, numDimsPerMap + ); + replacements.push_back( + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitArray, preciseSplitType) + ); + } + } + + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } +}; + +/// Split `array.read` from an array-of-POD into leaf reads plus local POD reconstruction. +class SplitPodArrayReadArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ReadArrayOp op) { + return !splittablePodArray(op.getArrRefType()) || shouldDeferPodArrayReadToStep3(op); + } + + LogicalResult matchAndRewrite( + ReadArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + ArrayType arrTy = op.getArrRefType(); + PodType podTy = llvm::cast(arrTy.getElementType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + rewriter.replaceOpWithNewOp(op, podTy); + return success(); + } + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + NewPodOp pod = rewriter.create(op.getLoc(), podTy); + DenseMap leafValues; + for (auto [id, splitArrRange] : llvm::zip_equal(splitIds, adaptor.getArrRef())) { + leafValues[id] = + genArrayRead(rewriter, op.getLoc(), getSingleConvertedValue(splitArrRange), indices); + } + + SmallVector recordChain; + for (RecordAttr record : podTy.getRecords()) { + recordChain.push_back(record.getName()); + Value recordValue = rebuildFlattenedPodRecord( + rewriter, op.getLoc(), record.getType(), recordChain, leafValues + ); + genWrite(rewriter, op.getLoc(), pod, record.getName(), recordValue); + recordChain.pop_back(); + } + rewriter.replaceOp(op, pod); + return success(); + } +}; + +/// Split `array.write` to an array-of-POD into one write per parallel leaf array. +class SplitPodArrayWriteArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(WriteArrayOp op) { return !splittablePodArray(op.getArrRefType()); } + + LogicalResult matchAndRewrite( + WriteArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + ArrayType arrTy = op.getArrRefType(); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + rewriter.eraseOp(op); + return success(); + } + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + Value podValue = getSingleConvertedValue(adaptor.getRvalue()); + for (auto [id, splitArrRange, splitType] : + llvm::zip_equal(splitIds, adaptor.getArrRef(), splitTypes)) { + Value leafValue = genReadAlongPath(rewriter, op.getLoc(), podValue, id); + genArrayWrite( + rewriter, op.getLoc(), getSingleConvertedValue(splitArrRange), indices, leafValue + ); + } + rewriter.eraseOp(op); + return success(); + } +}; + +/// Rewrite array-of-POD function signatures to use one parallel array per POD leaf. +class SplitPodArrayInFuncDefOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(FuncDefOp op) { + return !containsSplittablePodArrayType(op.getArgumentTypes()) && + !containsSplittablePodArrayType(op.getResultTypes()); + } + + LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); } + + LogicalResult + matchAndRewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { + const auto *tyConv = getTypeConverter(); + assert(tyConv && "expected pod-array type converter"); + + FunctionType oldTy = op.getFunctionType(); + TypeConverter::SignatureConversion inputConversion(oldTy.getNumInputs()); + if (failed(tyConv->convertSignatureArgs(oldTy.getInputs(), inputConversion))) { + return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod inputs"); + } + + SmallVector newResults; + if (failed(tyConv->convertTypes(oldTy.getResults(), newResults))) { + return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod results"); + } + + if (!op.getBody().empty() && + failed(rewriter.convertRegionTypes(&op.getBody(), *tyConv, &inputConversion))) { + return rewriter.notifyMatchFailure(op, "failed to convert function body block arguments"); + } + + SmallVector originalInputIdxToSize, originalResultIdxToSize; + SmallVector newInputs = convertPodArrayTypes(oldTy.getInputs(), &originalInputIdxToSize); + SmallVector newResultsWithSizeInfo = + convertPodArrayTypes(oldTy.getResults(), &originalResultIdxToSize); + assert( + newResultsWithSizeInfo == newResults && + "expected array-of-pod type conversion to match function result attr replication" + ); + SplitFunctionNameInfo inputNameInfo = + collectSplitFunctionNameInfo(op.getArgumentTypes(), [&](unsigned i) { + return op.getArgNameAttr(i); + }, getSplitPodArrayRecordNameSuffixes); + ArrayAttr resultAttrs = op.getAllResultAttrs(); + SplitFunctionNameInfo resultNameInfo = + collectSplitFunctionNameInfo(op.getResultTypes(), [resultAttrs](unsigned i) { + return getAttrAtIndexWithName(resultAttrs, i, RES_NAME_ATTR_NAME); + }, getSplitPodArrayRecordNameSuffixes); + + rewriter.modifyOpInPlace(op, [&]() { + op.setFunctionType(FunctionType::get(op.getContext(), newInputs, newResults)); + if (ArrayAttr newArgAttrs = replicateFunctionNameAttrsAsNeeded( + op.getArgAttrsAttr(), originalInputIdxToSize, newInputs, ARG_NAME_ATTR_NAME, + inputNameInfo.originalNames, inputNameInfo.existingNames, + inputNameInfo.splitNameSuffixes + )) { + op.setArgAttrsAttr(newArgAttrs); + } + if (ArrayAttr newResAttrs = replicateFunctionNameAttrsAsNeeded( + op.getResAttrsAttr(), originalResultIdxToSize, newResults, RES_NAME_ATTR_NAME, + resultNameInfo.originalNames, resultNameInfo.existingNames, + resultNameInfo.splitNameSuffixes + )) { + op.setResAttrsAttr(newResAttrs); + } + }); + return success(); + } +}; + +/// Append the split leaf-array values for one step-2 operand. +/// +/// When dialect conversion has already produced the parallel leaf arrays, reuse those converted +/// values directly. Otherwise derive the split arrays from the original aggregate operand so users +/// like `poly.unifiable_cast` and `function.return` can still flatten a raw `pod.read` of an array +/// field. +static void collectSplitPodArrayOperandValues( + Location loc, Value originalOperand, ValueRange convertedValues, + SmallVectorImpl &newOperands, ConversionPatternRewriter &rewriter +) { + ArrayType arrTy = splittablePodArray(originalOperand.getType()); + if (!arrTy) { + llvm::append_range(newOperands, convertedValues); + return; + } + + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(arrTy); + if (!convertedValues.empty()) { + newOperands.push_back(castValueToTypeIfNeeded( + rewriter, loc, getSingleConvertedValue(convertedValues), carrierTy + )); + return; + } + newOperands.push_back(materializeArrayLengthCarrier(originalOperand, arrTy, loc, rewriter)); + return; + } + + auto isDirectAggregateToSplitCast = [&convertedValues, &splitTypes]() { + if (convertedValues.empty()) { + return false; + } + auto castOp = convertedValues.front().getDefiningOp(); + if (!castOp || castOp->getNumOperands() != 1 || + !splittablePodArray(castOp.getOperand(0).getType()) || + castOp->getNumResults() != splitTypes.size()) { + return false; + } + + return llvm::all_of(llvm::zip_equal(convertedValues, splitTypes), [&castOp](auto pair) { + Value convertedValue = std::get<0>(pair); + Type splitType = std::get<1>(pair); + return convertedValue.getDefiningOp() == castOp && + typesUnify(convertedValue.getType(), splitType); + }); + }; + + if (!isDirectAggregateToSplitCast() && allValueTypesUnifyWithTypes(convertedValues, splitTypes)) { + llvm::append_range(newOperands, convertedValues); + return; + } + + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + Value splitValue = genReadAlongPath(rewriter, loc, originalOperand, id); + newOperands.push_back(castValueToTypeIfNeeded(rewriter, loc, splitValue, splitType)); + } +} + +/// Rewrite array-of-POD `poly.unifiable_cast` into one leaf-array cast per split array. +class SplitPodArrayInUnifiableCastOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(UnifiableCastOp op) { + return !splittablePodArray(op.getType()) && !splittablePodArray(op.getInput().getType()); + } + + LogicalResult matchAndRewrite( + UnifiableCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + ArrayType inputArrTy = splittablePodArray(op.getInput().getType()); + ArrayType resultArrTy = splittablePodArray(op.getType()); + + // When only the input is split, rebuild the aggregate input and preserve the original + // scalar/tvar result cast. + if (inputArrTy && !resultArrTy) { + SmallVector inputSplitTypes; + splitPodArrayTypeTo(inputArrTy, inputSplitTypes); + + SmallVector splitInputs; + collectSplitPodArrayOperandValues( + op.getLoc(), op.getInput(), adaptor.getInput(), splitInputs, rewriter + ); + if (inputSplitTypes.empty()) { + if (splitInputs.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected one shape carrier for zero-leaf array-of-pod cast input" + ); + } + } else if (splitInputs.size() != inputSplitTypes.size()) { + return rewriter.notifyMatchFailure( + op, "failed to collect one split input per array-of-pod cast leaf" + ); + } + + // `poly.unifiable_cast` to a non-array target cannot preserve all split leaf values in one + // SSA value without reintroducing aggregate array-of-POD materialization. + auto *it = llvm::find_if(splitInputs, [&op](Value v) { + return typesUnify(v.getType(), op.getType()); + }); + if (it == splitInputs.end()) { + return rewriter.notifyMatchFailure( + op, "failed to find split array leaf type compatible with cast target" + ); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), castValueToTypeIfNeeded(rewriter, op.getLoc(), *it, op.getType()) + ); + return success(); + } + + if (!inputArrTy) { + return rewriter.notifyMatchFailure( + op, "expected array-of-pod cast input when rewriting array-of-pod cast result" + ); + } + + SmallVector inputSplitIds; + SmallVector inputSplitTypes; + splitPodArrayTypeTo(inputArrTy, inputSplitTypes, &inputSplitIds); + + SmallVector resultSplitIds; + SmallVector resultSplitTypes; + splitPodArrayTypeTo(resultArrTy, resultSplitTypes, &resultSplitIds); + + if (inputSplitIds != resultSplitIds) { + return rewriter.notifyMatchFailure( + op, "array-of-pod cast changed POD leaf structure unexpectedly" + ); + } + + SmallVector splitInputs; + collectSplitPodArrayOperandValues( + op.getLoc(), op.getInput(), adaptor.getInput(), splitInputs, rewriter + ); + if (resultSplitTypes.empty()) { + if (splitInputs.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected one shape carrier for zero-leaf array-of-pod cast" + ); + } + rewriter.replaceOp( + op, castValueToTypeIfNeeded( + rewriter, op.getLoc(), splitInputs.front(), + getZeroLeafPodArrayShapeCarrierType(resultArrTy) + ) + ); + return success(); + } + if (splitInputs.size() != resultSplitTypes.size()) { + return rewriter.notifyMatchFailure( + op, "failed to collect one split input per array-of-pod cast leaf" + ); + } + + SmallVector replacements; + replacements.reserve(resultSplitTypes.size()); + for (auto [splitInput, resultSplitType] : llvm::zip_equal(splitInputs, resultSplitTypes)) { + replacements.push_back( + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitInput, resultSplitType) + ); + } + + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } +}; + +/// Rewrite `function.return` to flatten any array-of-POD operands into their parallel arrays. +class SplitPodArrayInReturnOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ReturnOp op) { + return !containsSplittablePodArrayType(op.getOperands().getTypes()); + } + + LogicalResult matchAndRewrite( + ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + SmallVector newOperands; + for (auto [operand, convertedValues] : + llvm::zip_equal(op.getOperands(), adaptor.getOperands())) { + collectSplitPodArrayOperandValues( + op.getLoc(), operand, convertedValues, newOperands, rewriter + ); + } + rewriter.replaceOpWithNewOp(op, ValueRange(newOperands)); + return success(); + } +}; + +/// Rewrite calls whose arguments or results contain arrays-of-POD to use the split signature. +class SplitPodArrayInCallOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(CallOp op) { + return !containsSplittablePodArrayType(op.getArgOperands().getTypes()) && + !containsSplittablePodArrayType(op.getResultTypes()); + } + + LogicalResult match(CallOp op) const override { return failure(legal(op)); } + + LogicalResult matchAndRewrite( + CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + const auto *tyConv = getTypeConverter(); + assert(tyConv && "expected pod-array type converter"); + + SmallVector newResultTypes; + if (failed(tyConv->convertTypes(op.getResultTypes(), newResultTypes))) { + return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod call results"); + } + + SmallVector> mapOperandStorage; + SmallVector mapOperands; + mapOperandStorage.reserve(adaptor.getMapOperands().size()); + mapOperands.reserve(adaptor.getMapOperands().size()); + for (ArrayRef mapOperandGroup : adaptor.getMapOperands()) { + mapOperandStorage.push_back(flattenConvertedValues(mapOperandGroup)); + } + for (const SmallVector &values : mapOperandStorage) { + mapOperands.push_back(values); + } + + SmallVector newArgOperands; + for (auto [operand, convertedValues] : + llvm::zip_equal(op.getArgOperands(), adaptor.getArgOperands())) { + collectSplitPodArrayOperandValues( + op.getLoc(), operand, convertedValues, newArgOperands, rewriter + ); + } + CallOp newCall = createCallPreservingInstantiationOperands( + op.getLoc(), newResultTypes, op, mapOperands, newArgOperands, rewriter + ); + + SmallVector> replacementStorage; + replacementStorage.reserve(op.getNumResults()); + auto newResultIt = newCall.getResults().begin(); + for (Type oldResultType : op.getResultTypes()) { + SmallVector convertedTypes; + (void)convertPodArrayTypeTo(oldResultType, convertedTypes); + SmallVector replacementsForResult; + replacementsForResult.reserve(convertedTypes.size()); + for (size_t i = 0; i < convertedTypes.size(); ++i) { + replacementsForResult.push_back(*newResultIt); + ++newResultIt; + } + replacementStorage.push_back(std::move(replacementsForResult)); + } + + SmallVector replacements; + replacements.reserve(replacementStorage.size()); + for (const SmallVector &values : replacementStorage) { + replacements.push_back(values); + } + rewriter.replaceOpWithMultiple(op, replacements); + return success(); + } +}; + +/// Rewrite `constrain.eq` over arrays-of-POD into one equality per parallel leaf array. +class SplitPodArrayInEmitEqualityOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(constrain::EmitEqualityOp op) { + return !containsSplittablePodArrayType(op->getOperandTypes()); + } + + LogicalResult matchAndRewrite( + constrain::EmitEqualityOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + if (ArrayType lhsTy = splittablePodArray(op.getLhs().getType()); + lhsTy && hasZeroLeafPodArraySplit(lhsTy)) { + Value lhsCarrier = + adaptor.getLhs().empty() + ? materializeArrayLengthCarrier(op.getLhs(), lhsTy, op.getLoc(), rewriter) + : getSingleConvertedValue(adaptor.getLhs()); + Value rhsCarrier = + adaptor.getRhs().empty() + ? materializeArrayLengthCarrier(op.getRhs(), lhsTy, op.getLoc(), rewriter) + : getSingleConvertedValue(adaptor.getRhs()); + for (size_t dim = 0, rank = lhsTy.getDimensionSizes().size(); dim < rank; ++dim) { + Value dimVal = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(llzk::checkedCast(dim)) + ); + Value lhsLen = rewriter.create(op.getLoc(), lhsCarrier, dimVal); + Value rhsLen = rewriter.create(op.getLoc(), rhsCarrier, dimVal); + rewriter.create(op.getLoc(), lhsLen, rhsLen); + } + rewriter.eraseOp(op); + return success(); + } + + if (adaptor.getLhs().size() != adaptor.getRhs().size()) { + return rewriter.notifyMatchFailure( + op, "expected array-of-pod equality operands to expand to the same number of leaves" + ); + } + + for (auto [lhs, rhs] : llvm::zip_equal(adaptor.getLhs(), adaptor.getRhs())) { + rewriter.create(op.getLoc(), lhs, rhs); + } + rewriter.eraseOp(op); + return success(); + } +}; + +/// Rewrite `constrain.in` over arrays-of-POD into a shared-slice witness plus leaf equalities. +/// +/// After step 2 converts an array-of-POD into parallel leaf arrays, `constrain.in` can no longer be +/// left in place because it has no built-in 1:N operand rewrite. This pattern preserves the +/// original containment semantics by: +/// +/// 1. Expanding both operands into matching POD leaves. +/// 2. Computing how many leading lhs dimensions must be selected to match the rhs rank. +/// 3. Creating one nondeterministic index per selected dimension and constraining each index to be +/// in bounds using `array.len` and `constrain.eq` on the comparison results. +/// 4. Using that same index tuple for every leaf, reading a scalar leaf with `array.read` or +/// extracting an array leaf with `array.extract`. +/// 5. Emitting one `constrain.eq` per selected lhs leaf and rhs leaf, then erasing the original +/// `constrain.in`. +/// +/// Reusing the same nondeterministic indices across all leaves is essential: it guarantees that all +/// field equalities refer to the same POD element or subarray, rather than allowing different +/// leaves to match at different positions. +class SplitPodArrayInEmitContainmentOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(constrain::EmitContainmentOp op) { + return !containsSplittablePodArrayType(op->getOperandTypes()); + } + + /// Return the split scalar or leaf-array values representing one containment operand. + static SmallVector collectContainmentLeaves( + Location loc, Value originalOperand, ValueRange convertedValues, + ConversionPatternRewriter &rewriter + ) { + if (ArrayType arrTy = splittablePodArray(originalOperand.getType()); + arrTy && hasZeroLeafPodArraySplit(arrTy)) { + return {}; + } + + if (splittablePod(originalOperand.getType())) { + SmallVector podLeaves; + processInputOperand(loc, getSingleConvertedValue(convertedValues), podLeaves, rewriter); + return podLeaves; + } + + return SmallVector(convertedValues.begin(), convertedValues.end()); + } + + LogicalResult matchAndRewrite( + constrain::EmitContainmentOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + Location loc = op.getLoc(); + ArrayType lhsTy = op.getLhs().getType(); + Type rhsTy = op.getRhs().getType(); + + size_t lhsRank = lhsTy.getDimensionSizes().size(); + size_t rhsRank = 0; + if (auto rhsArrTy = llvm::dyn_cast(rhsTy)) { + rhsRank = rhsArrTy.getDimensionSizes().size(); + } + assert(lhsRank >= rhsRank && "constrain.in verifier should reject higher-rank rhs arrays"); + size_t selectedDims = lhsRank - rhsRank; + + SmallVector lhsLeaves = + collectContainmentLeaves(loc, op.getLhs(), adaptor.getLhs(), rewriter); + SmallVector rhsLeaves = + collectContainmentLeaves(loc, op.getRhs(), adaptor.getRhs(), rewriter); + if (lhsLeaves.size() != rhsLeaves.size()) { + return rewriter.notifyMatchFailure( + op, "expected array-of-pod containment operands to expand to the same number of leaves" + ); + } + + Value shapeCarrier = + lhsLeaves.empty() ? (adaptor.getLhs().empty() + ? materializeArrayLengthCarrier(op.getLhs(), lhsTy, loc, rewriter) + : getSingleConvertedValue(adaptor.getLhs())) + : adaptor.getLhs().front(); + Value zero = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value trueVal = rewriter.create( + loc, IntegerAttr::get(IntegerType::get(rewriter.getContext(), 1), 1) + ); + + SmallVector selectedIndices; + selectedIndices.reserve(selectedDims); + for (size_t dim = 0; dim < selectedDims; ++dim) { + Value idx = rewriter.create(loc, IndexType::get(rewriter.getContext())); + Value dimVal = rewriter.create( + loc, rewriter.getIndexAttr(llzk::checkedCast(dim)) + ); + Value dimLen = rewriter.create(loc, shapeCarrier, dimVal); + + Value nonNegative = rewriter.create(loc, arith::CmpIPredicate::sge, idx, zero); + rewriter.create(loc, nonNegative, trueVal); + + Value inRange = rewriter.create(loc, arith::CmpIPredicate::slt, idx, dimLen); + rewriter.create(loc, inRange, trueVal); + + selectedIndices.push_back(idx); + } + + for (auto [lhsLeaf, rhsLeaf] : llvm::zip_equal(lhsLeaves, rhsLeaves)) { + Value selectedLhs = lhsLeaf; + if (auto rhsLeafArrTy = llvm::dyn_cast(rhsLeaf.getType())) { + if (!selectedIndices.empty()) { + selectedLhs = + rewriter.create(loc, rhsLeafArrTy, lhsLeaf, selectedIndices); + } + } else { + selectedLhs = + rewriter.create(loc, rhsLeaf.getType(), lhsLeaf, selectedIndices); + } + rewriter.create(loc, selectedLhs, rhsLeaf); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Return an array value whose visible rank still matches the original `array.len` source. +/// +/// Split POD leaves always preserve the original outer dimensions, but array-valued leaves append +/// their own inner dimensions. Dynamic dimension indices must not be able to observe those extra +/// leaf-only dimensions, so when every converted leaf has higher rank this synthesizes a shape-only +/// carrier with the original rank instead. +static Value selectArrayLengthShapeSource( + ArrayLengthOp op, ValueRange convertedArrRefs, ConversionPatternRewriter &rewriter +) { + size_t originalRank = op.getArrRefType().getDimensionSizes().size(); + for (Value arrRef : convertedArrRefs) { + auto arrTy = llvm::dyn_cast(arrRef.getType()); + assert(arrTy && "converted array-of-POD operand must stay an array"); + if (arrTy.getDimensionSizes().size() == originalRank) { + return arrRef; + } + } + + return materializeArrayLengthCarrier(op.getArrRef(), op.getArrRefType(), op.getLoc(), rewriter); +} + +/// Replace `array.length` on an array-of-POD with an equivalent rank-preserving array value. +class SplitPodArrayLengthOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ArrayLengthOp op) { return !splittablePodArray(op.getArrRefType()); } + + LogicalResult matchAndRewrite( + ArrayLengthOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + Value arrRef = selectArrayLengthShapeSource(op, adaptor.getArrRef(), rewriter); + rewriter.replaceOpWithNewOp( + op, arrRef, getSingleConvertedValue(adaptor.getDim()) + ); + return success(); + } +}; + +/// Rebuild the current quantifier iterand from one read or extract per split POD-array leaf. +static Value rebuildSplitPodArrayQuantifierIterValue( + OpBuilder &bldr, Location loc, Type iterType, Value index, ArrayType sortType, + ValueRange convertedSort +) { + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(sortType, splitTypes, &splitIds); + if (splitTypes.empty()) { + return bldr.create(loc, llvm::cast(iterType)); + } + assert( + convertedSort.size() == splitIds.size() && + "converted quantifier sort must provide one value per POD-array leaf" + ); + + DenseMap leafValues; + for (auto [id, leafArray] : llvm::zip_equal(splitIds, convertedSort)) { + SmallVector indices {index}; + leafValues[id] = genArrayRead(bldr, loc, leafArray, indices); + } + + SmallVector recordChain; + return rebuildFlattenedPodRecord(bldr, loc, iterType, recordChain, leafValues); +} + +/// Lower a bool quantifier over an array-of-POD to an `scf.for` over the split leaf arrays. +template +static LogicalResult rewriteSplitPodArrayQuantifier( + QuantifierOp op, ValueRange convertedSort, ConversionPatternRewriter &rewriter, + bool initialValue +) { + ArrayType sortType = llvm::cast(op.getSort().getType()); + Location loc = op.getLoc(); + + Value shapeCarrier = convertedSort.empty() + ? materializeArrayLengthCarrier(op.getSort(), sortType, loc, rewriter) + : convertedSort.front(); + Value lowerBound = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value upperBound = rewriter.create(loc, shapeCarrier, lowerBound); + Value step = rewriter.create(loc, rewriter.getIndexAttr(1)); + Value init = rewriter.create( + loc, IntegerAttr::get(IntegerType::get(rewriter.getContext(), 1), initialValue ? 1 : 0) + ); + + auto loop = rewriter.create(loc, lowerBound, upperBound, step, ValueRange {init}); + loop->setDiscardableAttrs(op->getDiscardableAttrDictionary()); + + Block &loopBody = *loop.getBody(); + if (!loopBody.empty()) { + rewriter.eraseOp(&loopBody.back()); + } + + rewriter.setInsertionPointToStart(&loopBody); + Value iterValue = rebuildSplitPodArrayQuantifierIterValue( + rewriter, loc, op.getBody()->getArgument(0).getType(), loop.getInductionVar(), sortType, + convertedSort + ); + + IRMapping mapping; + mapping.map(op.getBody()->getArgument(0), iterValue); + + for (Operation &nestedOp : op.getBody()->without_terminator()) { + rewriter.clone(nestedOp, mapping); + } + + auto yieldOp = llvm::cast(op.getBody()->getTerminator()); + Value predicate = mapping.lookupOrDefault(yieldOp.getValue()); + Value combined = rewriter.create(loc, loop.getRegionIterArg(0), predicate); + rewriter.create(loc, combined); + + rewriter.replaceOp(op, loop.getResults()); + return success(); +} + +/// Rewrite `bool.forall` over an array-of-POD to iterate over the split leaf arrays directly. +class SplitPodArrayForAllOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(boolean::ForAllOp op) { return !splittablePodArray(op.getSort().getType()); } + + LogicalResult matchAndRewrite( + boolean::ForAllOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + return rewriteSplitPodArrayQuantifier( + op, adaptor.getSort(), rewriter, /*initialValue=*/true + ); + } +}; + +/// Rewrite `bool.exists` over an array-of-POD to iterate over the split leaf arrays directly. +class SplitPodArrayExistsOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(boolean::ExistsOp op) { return !splittablePodArray(op.getSort().getType()); } + + LogicalResult matchAndRewrite( + boolean::ExistsOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + return rewriteSplitPodArrayQuantifier( + op, adaptor.getSort(), rewriter, /*initialValue=*/false + ); + } +}; + +/// Rewrite `array.extract` of an array-of-POD subarray into one extract per parallel leaf array. +class SplitPodArrayExtractArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ExtractArrayOp op) { return !splittablePodArray(op.getResult().getType()); } + + LogicalResult matchAndRewrite( + ExtractArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + SmallVector splitResultTypes; + splitPodArrayTypeTo(op.getResult().getType(), splitResultTypes); + if (splitResultTypes.empty()) { + ArrayType resultTy = llvm::cast(op.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, getZeroLeafPodArrayShapeCarrierType(resultTy), + getSingleConvertedValue(adaptor.getArrRef()), flattenConvertedValues(adaptor.getIndices()) + ); + return success(); + } + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + SmallVector replacements; + replacements.reserve(splitResultTypes.size()); + for (auto [splitArrRange, splitResultType] : + llvm::zip_equal(adaptor.getArrRef(), splitResultTypes)) { + replacements.push_back(rewriter.create( + op.getLoc(), llvm::cast(splitResultType), + getSingleConvertedValue(splitArrRange), indices + )); + } + + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } +}; + +/// Rewrite `array.insert` of an array-of-POD subarray into one insert per parallel leaf array. +class SplitPodArrayInsertArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(InsertArrayOp op) { return !splittablePodArray(op.getRvalue().getType()); } + + LogicalResult matchAndRewrite( + InsertArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + if (hasZeroLeafPodArraySplit(llvm::cast(op.getRvalue().getType()))) { + rewriter.create( + op.getLoc(), getSingleConvertedValue(adaptor.getArrRef()), + flattenConvertedValues(adaptor.getIndices()), getSingleConvertedValue(adaptor.getRvalue()) + ); + rewriter.eraseOp(op); + return success(); + } + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + for (auto [splitArrRange, splitRvalueRange] : + llvm::zip_equal(adaptor.getArrRef(), adaptor.getRvalue())) { + rewriter.create( + op.getLoc(), getSingleConvertedValue(splitArrRange), indices, + getSingleConvertedValue(splitRvalueRange) + ); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Rewrite a write to a split array-of-POD struct member into writes to each parallel array member. +class SplitPodArrayInMemberWriteOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + +public: + SplitPodArrayInMemberWriteOp( + const TypeConverter &converter, MLIRContext *ctx, SymbolTableCollection &symTables, + const MemberReplacementMap &memberRepMap + ) + : OpConversionPattern(converter, ctx), tables(symTables), + repMapRef(memberRepMap) {} + + static bool legal(MemberWriteOp op) { return !splittablePodArray(op.getVal().getType()); } + + LogicalResult matchAndRewrite( + MemberWriteOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); + + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + ArrayType arrTy = llvm::cast(op.getVal().getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + const MemberInfo &carrierMember = idToMember.at(RecordChain()); + rewriter.create( + op.getLoc(), getSingleConvertedValue(adaptor.getComponent()), + FlatSymbolRefAttr::get(carrierMember.first), getSingleConvertedValue(adaptor.getVal()) + ); + rewriter.eraseOp(op); + return success(); + } + + for (auto [id, splitValRange] : llvm::zip_equal(splitIds, adaptor.getVal())) { + const MemberInfo &newMember = idToMember.at(id); + rewriter.create( + op.getLoc(), getSingleConvertedValue(adaptor.getComponent()), + FlatSymbolRefAttr::get(newMember.first), getSingleConvertedValue(splitValRange) + ); + } rewriter.eraseOp(op); + return success(); + } +}; + +/// Rewrite a read from a split array-of-POD struct member into reads of each parallel array member. +class SplitPodArrayInMemberReadOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + +public: + SplitPodArrayInMemberReadOp( + const TypeConverter &converter, MLIRContext *ctx, SymbolTableCollection &symTables, + const MemberReplacementMap &memberRepMap + ) + : OpConversionPattern(converter, ctx), tables(symTables), + repMapRef(memberRepMap) {} + + static bool legal(MemberReadOp op) { return !splittablePodArray(op.getResult().getType()); } + + LogicalResult matchAndRewrite( + MemberReadOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); + + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + ArrayType arrTy = llvm::cast(op.getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + SmallVector mapOperands; + std::optional numDimsPerMap; + auto mapOperandsOld = adaptor.getMapOperands(); + if (!mapOperandsOld.empty()) { + assert( + mapOperandsOld.size() == 1 && + "member.readm should have at most one affine-map operand group" + ); + mapOperands = flattenConvertedValues(mapOperandsOld.front()); + + ArrayRef numDimsPerMapOld = op.getNumDimsPerMap(); + if (!numDimsPerMapOld.empty()) { + assert( + numDimsPerMapOld.size() == 1 && + "member.readm should have one numDims entry per affine-map group" + ); + numDimsPerMap = numDimsPerMapOld.front(); + } + } + if (splitTypes.empty()) { + const MemberInfo &carrierMember = idToMember.at(RecordChain()); + Value carrierRead = rewriter.create( + op.getLoc(), carrierMember.second, getSingleConvertedValue(adaptor.getComponent()), + carrierMember.first, op.getTableOffset().value_or(nullptr), mapOperands, numDimsPerMap + ); + rewriter.replaceOpWithMultiple(op, {ValueRange {carrierRead}}); + return success(); + } + SmallVector replacements; + replacements.reserve(splitIds.size()); + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + const MemberInfo &newMember = idToMember.at(id); + replacements.push_back(rewriter.create( + op.getLoc(), splitType, getSingleConvertedValue(adaptor.getComponent()), newMember.first, + op.getTableOffset().value_or(nullptr), mapOperands, numDimsPerMap + )); + } + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); } }; -/// Replace `PodType` struct members with scalar members. +/// Split arrays-of-POD into parallel arrays before direct pod scalarization. static LogicalResult -step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) { +step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { MLIRContext *ctx = modOp.getContext(); + PodArrayTypeConverter typeConverter; RewritePatternSet patterns(ctx); - - patterns.add(ctx, symTables, memberRepMap); + patterns.add< + SplitPodArrayNonDetOp, SplitPodArrayCreateArrayOp, SplitPodArrayReadArrayOp, + SplitPodArrayWriteArrayOp, SplitPodArrayExtractArrayOp, SplitPodArrayInsertArrayOp, + SplitPodArrayInFuncDefOp, SplitPodArrayInUnifiableCastOp, SplitPodArrayInReturnOp, + SplitPodArrayInCallOp, SplitPodArrayInEmitEqualityOp, SplitPodArrayInEmitContainmentOp, + SplitPodArrayLengthOp, SplitPodArrayForAllOp, SplitPodArrayExistsOp>(typeConverter, ctx); + patterns.add( + typeConverter, ctx, symTables, memberRepMap + ); ConversionTarget target(*ctx); baseTargetSetup(target); - target.addDynamicallyLegalOp(SplitPodInMemberDefOp::legal); + target.addLegalOp(); + target.addDynamicallyLegalOp(SplitPodArrayNonDetOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayCreateArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayReadArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayWriteArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayExtractArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInsertArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInFuncDefOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInUnifiableCastOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInReturnOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInCallOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInEmitEqualityOp::legal); + target.addDynamicallyLegalOp( + SplitPodArrayInEmitContainmentOp::legal + ); + target.addDynamicallyLegalOp(SplitPodArrayLengthOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayForAllOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayExistsOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInMemberWriteOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInMemberReadOp::legal); + + mlir::scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, target); - LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split pod-type members\n";); + LLVM_DEBUG(llvm::dbgs() << "Begin step 2: split arrays with POD element type\n";); return applyFullConversion(modOp, target, std::move(patterns)); } @@ -441,6 +2659,87 @@ class SplitInitFromNewPodOp : public OpConversionPattern { } }; +/// Rewrite `array.new` when explicit elements are PODs or flattened leaf arrays. +/// +/// This occurs after the array-of-POD stage has already converted the result type away from +/// `!array.type<... x !pod.type<...>>`, but before the POD operands themselves have been fully +/// scalarized. Rebuild the destination array explicitly so leaf arrays become subarray inserts +/// rather than invalid inline operands to the flattened `array.new`. +class SplitPodElementCreateArrayOp : public OpConversionPattern { + const VirtualPodValueMap &virtualPods; + +public: + SplitPodElementCreateArrayOp(MLIRContext *ctx, const VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + + static bool legal(CreateArrayOp op) { + return !llvm::any_of(op.getElements().getTypes(), [](Type type) { + return splittablePod(type) || llvm::isa(type); + }); + } + + LogicalResult match(CreateArrayOp op) const override { return failure(legal(op)); } + + void + rewrite(CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector leafElements; + leafElements.reserve(adaptor.getElements().size()); + + Type leafType; + for (Value element : adaptor.getElements()) { + SmallVector flattenedValues; + if (splittablePod(element.getType())) { + processInputOperand( + op.getLoc(), element, flattenedValues, rewriter, op.getOperation(), &virtualPods + ); + } else { + flattenedValues.push_back(element); + } + + assert( + flattenedValues.size() == 1 && + "array.new elements should already have been split to a single flattened leaf" + ); + if (!leafType) { + leafType = flattenedValues.front().getType(); + } else { + assert( + leafType == flattenedValues.front().getType() && "array.new elements must stay uniform" + ); + } + leafElements.push_back(flattenedValues.front()); + } + + size_t leafRank = 0; + if (auto leafArrTy = llvm::dyn_cast_if_present(leafType)) { + leafRank = leafArrTy.getDimensionSizes().size(); + } + ArrayType arrTy = op.getType(); + assert( + arrTy.getDimensionSizes().size() >= leafRank && "flattened leaf rank exceeds array rank" + ); + size_t outerRank = arrTy.getDimensionSizes().size() - leafRank; + assert(outerRank > 0 && "array.new elements must populate at least one outer array dimension"); + + ArrayType outerIndexTy = + ArrayType::get(arrTy.getElementType(), arrTy.getDimensionSizes().take_front(outerRank)); + auto elementIndices = outerIndexTy.getSubelementIndices(); + assert( + elementIndices && "array.new with explicit POD elements requires static outer dimensions" + ); + assert( + elementIndices->size() == leafElements.size() && + "array.new element count must match the outer array cardinality" + ); + + Value rebuiltArray = createWritableArrayValue(rewriter, op.getLoc(), arrTy); + for (auto [index, leafValue] : llvm::zip_equal(*elementIndices, leafElements)) { + genArrayWrite(rewriter, op.getLoc(), rebuiltArray, index, leafValue); + } + rewriter.replaceOp(op, rebuiltArray); + } +}; + /// Rewrite pod-typed function signatures to pass one scalar per POD record instead. /// /// Each pod argument is expanded into one scalar argument per record, and each pod result is @@ -449,8 +2748,11 @@ class SplitInitFromNewPodOp : public OpConversionPattern { /// rest of the function can continue to use POD values until later cleanup passes scalarize those /// local temporaries away. class SplitPodInFuncDefOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + public: - using OpConversionPattern::OpConversionPattern; + SplitPodInFuncDefOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} inline static bool legal(FuncDefOp op) { return !containsSplittablePodType(op.getArgumentTypes()) && @@ -465,6 +2767,7 @@ class SplitPodInFuncDefOp : public OpConversionPattern { SmallVector originalInputIdxToSize, originalResultIdxToSize; SplitFunctionNameInfo inputNameInfo; SplitFunctionNameInfo resultNameInfo; + VirtualPodValueMap &virtualPods; protected: SmallVector convertInputs(ArrayRef origTypes) override { @@ -500,18 +2803,21 @@ class SplitPodInFuncDefOp : public OpConversionPattern { Value oldV = entryBlock.getArgument(i); if (PodType pt = splittablePod(oldV.getType())) { Location loc = oldV.getLoc(); - // Generate `NewPodOp` and replace uses of the argument with it. - auto newPod = rewriter.create(loc, pt); - rewriter.replaceAllUsesWith(oldV, newPod); - // Remove the argument from the block + VirtualPodLeafMap leafValues; + SmallVector recordChain; + unsigned nextArgIdx = i + 1; + forEachPodLeaf(pt, recordChain, [&](const RecordChain &id, Type leafType) { + BlockArgument newArg = entryBlock.insertArgument(nextArgIdx, leafType, loc); + leafValues[id] = newArg; + ++nextArgIdx; + }); + + Value virtualPod = createVirtualPodPlaceholder(rewriter, loc, pt, leafValues); + rewriter.replaceAllUsesWith(oldV, virtualPod); entryBlock.eraseArgument(i); - // For all indices in the PodType (i.e., the element count), generate a new - // block argument and a write of that argument to the new pod. - for (RecordAttr record : pt.getRecords()) { - BlockArgument newArg = entryBlock.insertArgument(i, record.getType(), loc); - genWrite(loc, newPod, record.getName(), newArg, rewriter); - ++i; - } + + i += leafValues.size(); + virtualPods[virtualPod] = std::move(leafValues); } else { ++i; } @@ -519,7 +2825,7 @@ class SplitPodInFuncDefOp : public OpConversionPattern { } public: - Impl(FuncDefOp op) { + Impl(FuncDefOp op, VirtualPodValueMap &virtualPodMap) : virtualPods(virtualPodMap) { inputNameInfo = collectSplitFunctionNameInfo(op.getArgumentTypes(), [&op](unsigned i) { return op.getArgNameAttr(i); }, getSplitRecordNameSuffixes); @@ -530,7 +2836,7 @@ class SplitPodInFuncDefOp : public OpConversionPattern { ); } }; - Impl(op).convert(op, rewriter); + Impl(op, virtualPods).convert(op, rewriter); } }; @@ -540,278 +2846,633 @@ class SplitPodInFuncDefOp : public OpConversionPattern { /// are returned as one SSA value per record, using local `pod.read` operations to extract the /// scalar pieces immediately before the return. class SplitPodInReturnOp : public OpConversionPattern { + const VirtualPodValueMap &virtualPods; + public: - using OpConversionPattern::OpConversionPattern; + SplitPodInReturnOp(MLIRContext *ctx, const VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + + inline static bool legal(ReturnOp op) { + return !containsSplittablePodType(op.getOperands().getTypes()); + } + + LogicalResult match(ReturnOp op) const override { return failure(legal(op)); } + + void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + processInputOperands( + adaptor.getOperands(), op.getOperandsMutable(), op, rewriter, &virtualPods + ); + } +}; + +/// Rebuild a call with split scalar results, then reconstruct POD-typed results locally. +static CallOp newCallOpWithSplitResults( + CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, + VirtualPodValueMap &virtualPods +) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(oldCall); + + Operation::result_range oldResults = oldCall.getResults(); + CallOp newCall = createCallPreservingInstantiationOperands( + oldCall.getLoc(), splitPodType(oldResults.getTypes()), oldCall, adaptor.getMapOperands(), + adaptor.getArgOperands(), rewriter + ); + + auto newResults = newCall.getResults().begin(); + for (Value oldVal : oldResults) { + if (PodType pt = splittablePod(oldVal.getType())) { + Location loc = oldVal.getLoc(); + VirtualPodLeafMap leafValues; + SmallVector recordChain; + forEachPodLeaf(pt, recordChain, [&leafValues, &newResults](const RecordChain &id, Type) { + leafValues[id] = *newResults; + ++newResults; + }); + Value virtualPod = createVirtualPodPlaceholder(rewriter, loc, pt, leafValues); + virtualPods[virtualPod] = std::move(leafValues); + rewriter.replaceAllUsesWith(oldVal, virtualPod); + } else { + rewriter.replaceAllUsesWith(oldVal, *newResults); + newResults++; + } + } + // erase the original CallOp + rewriter.eraseOp(oldCall); + + return newCall; +} + +/// Rewrite calls whose arguments or results contain PODs to use flattened scalar signatures. +/// +/// POD arguments are decomposed into scalar record operands before the new call is formed. POD +/// results are reconstructed locally after the call with `pod.new` plus `pod.write`, preserving +/// the original POD-typed uses in the caller until later optimization passes remove the temporary +/// POD allocations. +class SplitPodInCallOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + +public: + SplitPodInCallOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + + inline static bool legal(CallOp op) { + return !containsSplittablePodType(op.getArgOperands().getTypes()) && + !containsSplittablePodType(op.getResultTypes()); + } + + LogicalResult match(CallOp op) const override { return failure(legal(op)); } + + void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Create new CallOp with split results first so, then process its inputs to split types + CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter, virtualPods); + processInputOperands( + newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter, &virtualPods + ); + } +}; + +/// Rewrite a write to a pod-typed struct member into writes to the corresponding scalar leaves. +class SplitPodInMemberWriteOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + const VirtualPodValueMap &virtualPods; + +public: + SplitPodInMemberWriteOp( + MLIRContext *ctx, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap, + const VirtualPodValueMap &virtualPodMap + ) + : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap), + virtualPods(virtualPodMap) {} + + static bool legal(MemberWriteOp op) { return !containsSplittablePodType(op.getVal().getType()); } + + LogicalResult match(MemberWriteOp op) const override { return failure(legal(op)); } + + void + rewrite(MemberWriteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); + + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + const VirtualPodLeafMap *virtualLeafValues = + !hasEarlierWriteToPodInBlock(op.getOperation(), adaptor.getVal()) + ? lookupVirtualPodLeafMap(adaptor.getVal(), virtualPods) + : nullptr; + + for (const auto &[id, newMember] : idToMember) { + Value scalarValue = virtualLeafValues + ? virtualLeafValues->at(id) + : genReadAlongPath(rewriter, op.getLoc(), adaptor.getVal(), id); + rewriter.create( + op.getLoc(), adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarValue + ); + } + rewriter.eraseOp(op); + } +}; + +/// Rewrite a read from a pod-typed struct member into reads from the corresponding scalar leaves. +class SplitPodInMemberReadOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + VirtualPodValueMap &virtualPods; + +public: + SplitPodInMemberReadOp( + MLIRContext *ctx, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap, + VirtualPodValueMap &virtualPodMap + ) + : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap), + virtualPods(virtualPodMap) {} + + static bool legal(MemberReadOp op) { + return !containsSplittablePodType(op.getResult().getType()); + } + + LogicalResult match(MemberReadOp op) const override { return failure(legal(op)); } + + void + rewrite(MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); + + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + + VirtualPodLeafMap leafValues; + for (const auto &[id, newMember] : idToMember) { + leafValues[id] = rewriter.create( + op.getLoc(), newMember.second, adaptor.getComponent(), newMember.first + ); + } + + PodType podTy = llvm::cast(op.getType()); + Value virtualPod = createVirtualPodPlaceholder(rewriter, op.getLoc(), podTy, leafValues); + virtualPods[virtualPod] = std::move(leafValues); + rewriter.replaceOp(op, virtualPod); + } +}; + +/// Collect precise split leaf arrays from a value re-materialized as an aggregate array-of-POD. +/// +/// This recognizes the temporary aggregate form produced by dialect conversion casts and unwraps +/// it back into the parallel split arrays expected by the late pod-array read resolvers. +static bool tryCollectMaterializedSplitPodArrayLeafValues( + Value arrayValue, ArrayType arrTy, ArrayRef splitTypes, SmallVectorImpl &leafArrays +) { + auto cast = arrayValue.getDefiningOp(); + if (!cast || cast->getNumResults() != 1 || cast.getResult(0).getType() != arrTy || + cast->getNumOperands() != splitTypes.size()) { + return false; + } + + for (auto [operand, splitType] : llvm::zip_equal(cast.getOperands(), splitTypes)) { + if (operand.getType() != splitType) { + return false; + } + leafArrays.push_back(operand); + } + return true; +} + +/// Collect precise split leaf arrays for an array-of-POD value backed by a direct `pod.read`. +/// +/// This first consults virtual POD leaf storage and, if unavailable, falls back to forwarding +/// through a dominating same-record `pod.write` whose value was previously materialized as split +/// arrays. +static bool tryCollectReadPodSplitPodArrayLeafValues( + ReadPodOp readOp, ArrayType arrTy, ArrayRef splitIds, ArrayRef splitTypes, + const VirtualPodValueMap &virtualPods, SmallVectorImpl &leafArrays +) { + if (WritePodOp writeOp = findNearestForwardableWriteInBlock(readOp)) { + return tryCollectMaterializedSplitPodArrayLeafValues( + writeOp.getValue(), arrTy, splitTypes, leafArrays + ); + } + + if (!hasEarlierWriteInBlock(readOp)) { + if (const VirtualPodLeafMap *podLeafValues = + lookupVirtualPodLeafMap(readOp.getPodRef(), virtualPods)) { + leafArrays.reserve(splitIds.size()); + for (const RecordChain &id : splitIds) { + SmallVector fullChain {readOp.getRecordNameAttr()}; + llvm::append_range(fullChain, id.nameList); + auto it = podLeafValues->find(RecordChain(fullChain)); + if (it == podLeafValues->end() || + !typesUnify(it->second.getType(), getFlattenedTypeAlongPath(arrTy, id.nameList))) { + return false; + } + leafArrays.push_back(it->second); + } + return true; + } + } + + return false; +} + +/// Materialize or recover split leaf arrays for a dynamic array-of-POD produced by `pod.read`. +static bool resolveReadPodSplitPodArrayLeafValues( + ReadPodOp readOp, ArrayType arrTy, ArrayRef splitIds, ArrayRef splitTypes, + const VirtualPodValueMap &virtualPods, DeferredPodArrayLeafMap &deferredPodArrays, Location loc, + OpBuilder &bldr, SmallVectorImpl &leafArrays +) { + if (tryCollectReadPodSplitPodArrayLeafValues( + readOp, arrTy, splitIds, splitTypes, virtualPods, leafArrays + )) { + return true; + } - inline static bool legal(ReturnOp op) { - return !containsSplittablePodType(op.getOperands().getTypes()); + if (!isFreshUnwrittenPodRead(readOp)) { + return false; } - LogicalResult match(ReturnOp op) const override { return failure(legal(op)); } + // Reuse one synthetic split-array backing per deferred field read so repeated users of the same + // aggregate value continue to observe the same unwritten leaf storage. + auto [it, inserted] = deferredPodArrays.try_emplace(readOp.getResult()); + leafArrays.assign(it->second.begin(), it->second.end()); + if (inserted) { + OpBuilder::InsertionGuard guard(bldr); + bldr.setInsertionPointAfter(readOp); + leafArrays.reserve(splitTypes.size()); + for (Type splitType : splitTypes) { + leafArrays.push_back(createWritableArrayValue(bldr, loc, llvm::cast(splitType))); + } + it->second.assign(leafArrays.begin(), leafArrays.end()); + } else { + assert( + leafArrays.size() == splitTypes.size() && + "cached split POD arrays must match the rewritten read arity" + ); + } - void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter); + return true; +} + +/// Erase a resolved deferred field-read chain once both the read and its placeholder pod vanish. +static void eraseDeadDeferredFieldReadChain(ReadPodOp readOp, PatternRewriter &rewriter) { + if (!readOp.getResult().use_empty()) { + return; } -}; -/// Rebuild a call with split scalar results, then reconstruct POD-typed results locally. -static CallOp newCallOpWithSplitResults( - CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter + Value podRef = readOp.getPodRef(); + rewriter.eraseOp(readOp); + if (podRef.use_empty()) { + if (auto cast = podRef.getDefiningOp()) { + if (cast->getNumResults() == 1 && cast.getResult(0) == podRef) { + rewriter.eraseOp(cast); + } + } + } +} + +/// Return `true` iff `op` is a deferred split placeholder for one array-of-POD aggregate value. +static bool getDeferredSplitPodArrayCastInfo( + UnrealizedConversionCastOp op, ArrayType &arrTy, SmallVector &splitIds, + SmallVectorImpl &splitTypes ) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(oldCall); + if (op->getNumOperands() != 1) { + return false; + } - Operation::result_range oldResults = oldCall.getResults(); - CallOp newCall = createCallPreservingInstantiationOperands( - oldCall.getLoc(), splitPodType(oldResults.getTypes()), oldCall, adaptor.getMapOperands(), - adaptor.getArgOperands(), rewriter - ); + arrTy = splittablePodArray(op.getOperand(0).getType()); + if (!arrTy) { + return false; + } - auto newResults = newCall.getResults().begin(); - for (Value oldVal : oldResults) { - if (PodType pt = splittablePod(oldVal.getType())) { - Location loc = oldVal.getLoc(); - // Generate `NewPodOp` and replace uses of the result with it. - auto newPod = rewriter.create(loc, pt); - rewriter.replaceAllUsesWith(oldVal, newPod); + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (op->getNumResults() != splitTypes.size()) { + return false; + } - // For each record in the PodType, write the next result from the new CallOp to the new pod. - for (RecordAttr record : pt.getRecords()) { - genWrite(loc, newPod, record.getName(), *newResults, rewriter); - newResults++; - } - } else { - rewriter.replaceAllUsesWith(oldVal, *newResults); - newResults++; + for (auto [result, splitType] : llvm::zip_equal(op.getResults(), splitTypes)) { + if (result.getType() != splitType) { + return false; } } - // erase the original CallOp - rewriter.eraseOp(oldCall); - - return newCall; + return true; } -/// Rewrite calls whose arguments or results contain PODs to use flattened scalar signatures. +/// Resolve deferred `array.read` from `pod.read`-produced array-of-POD values. /// -/// POD arguments are decomposed into scalar record operands before the new call is formed. POD -/// results are reconstructed locally after the call with `pod.new` plus `pod.write`, preserving -/// the original POD-typed uses in the caller until later optimization passes remove the temporary -/// POD allocations. -class SplitPodInCallOp : public OpConversionPattern { +/// When step 2 defers a read because the array-of-POD came from a POD record, this pattern +/// reconstructs the per-leaf split arrays, performs the array read on each leaf array, and then +/// rebuilds the element POD virtually instead of materializing the whole aggregate array first. +class ResolvePodReadBackedArrayReadOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + DeferredPodArrayLeafMap &deferredPodArrays; + public: - using OpConversionPattern::OpConversionPattern; + ResolvePodReadBackedArrayReadOp( + MLIRContext *ctx, VirtualPodValueMap &virtualPodMap, + DeferredPodArrayLeafMap &deferredPodArrayMap + ) + : OpConversionPattern(ctx), virtualPods(virtualPodMap), + deferredPodArrays(deferredPodArrayMap) {} - inline static bool legal(CallOp op) { - return !containsSplittablePodType(op.getArgOperands().getTypes()) && - !containsSplittablePodType(op.getResultTypes()); + static bool canResolve(ReadArrayOp op, const VirtualPodValueMap &virtualPods) { + if (!shouldDeferPodArrayReadToStep3(op)) { + return false; + } + + ArrayType arrTy = op.getArrRefType(); + auto fieldRead = llvm::cast(op.getArrRef().getDefiningOp()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector ignoredLeafArrays; + return tryCollectReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, ignoredLeafArrays + ) || + isFreshUnwrittenPodRead(fieldRead); } - LogicalResult match(CallOp op) const override { return failure(legal(op)); } + LogicalResult matchAndRewrite( + ReadArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + auto fieldRead = op.getArrRef().getDefiningOp(); + if (!fieldRead) { + return failure(); + } - void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Create new CallOp with split results first so, then process its inputs to split types - CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter); - processInputOperands( - newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter - ); + ArrayType arrTy = op.getArrRefType(); + PodType podTy = llvm::cast(arrTy.getElementType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector splitLeafArrays; + if (!resolveReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, deferredPodArrays, op.getLoc(), + rewriter, splitLeafArrays + )) { + return failure(); + } + + SmallVector indices(adaptor.getIndices().begin(), adaptor.getIndices().end()); + VirtualPodLeafMap leafValues; + for (auto [id, leafArray] : llvm::zip_equal(splitIds, splitLeafArrays)) { + leafValues[id] = genArrayRead(rewriter, op.getLoc(), leafArray, indices); + } + + Value virtualPod = createVirtualPodPlaceholder(rewriter, op.getLoc(), podTy, leafValues); + virtualPods[virtualPod] = std::move(leafValues); + rewriter.replaceOp(op, virtualPod); + eraseDeadDeferredFieldReadChain(fieldRead, rewriter); + return success(); } }; -/// Read a nested POD leaf by following each record name in `recordChain`. -static Value -genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &rewriter) { - Value value = podRef; - for (StringAttr attr : recordChain.nameList) { - value = genRead(loc, value, attr, rewriter); +/// Resolve deferred split-array placeholders created while flattening direct POD operands. +/// +/// Step 2 may need one specific split leaf array from a dynamic array-of-POD field before step 3 +/// has converted the surrounding POD value into virtual leaf storage. In that case +/// `genReadAlongPath` leaves behind a `builtin.unrealized_conversion_cast` from the aggregate field +/// read to all split leaf arrays, and this pattern resolves that placeholder once the backing leaf +/// arrays become available. +class ResolveDeferredSplitPodArrayCastOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + DeferredPodArrayLeafMap &deferredPodArrays; + +public: + ResolveDeferredSplitPodArrayCastOp( + MLIRContext *ctx, VirtualPodValueMap &virtualPodMap, + DeferredPodArrayLeafMap &deferredPodArrayMap + ) + : OpConversionPattern(ctx), virtualPods(virtualPodMap), + deferredPodArrays(deferredPodArrayMap) {} + + static bool canResolve(UnrealizedConversionCastOp op, const VirtualPodValueMap &virtualPods) { + ArrayType arrTy; + SmallVector splitIds; + SmallVector splitTypes; + if (!getDeferredSplitPodArrayCastInfo(op, arrTy, splitIds, splitTypes)) { + return false; + } + + ReadPodOp fieldRead = peelUnifiableCasts(op.getOperand(0)).getDefiningOp(); + if (!fieldRead) { + return false; + } + + SmallVector ignoredLeafArrays; + return tryCollectReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, ignoredLeafArrays + ) || + isFreshUnwrittenPodRead(fieldRead); } - return value; -} -/// State used while rebuilding a POD from flattened struct-member leaves. -struct RebuildPodReadState { - NewPodOp pod; - DenseMap leafValues; -}; + LogicalResult matchAndRewrite( + UnrealizedConversionCastOp op, OpAdaptor, ConversionPatternRewriter &rewriter + ) const override { + ArrayType arrTy; + SmallVector splitIds; + SmallVector splitTypes; + if (!getDeferredSplitPodArrayCastInfo(op, arrTy, splitIds, splitTypes)) { + return failure(); + } -/// Reconstruct a POD record from the leaf values collected while splitting `struct.readm`. -static Value rebuildFlattenedPodRecord( - Location loc, Type recordType, SmallVectorImpl &recordChain, - const DenseMap &leafValues, ConversionPatternRewriter &rewriter -) { - if (PodType nestedPodTy = dyn_cast(recordType)) { - NewPodOp nestedPod = rewriter.create(loc, nestedPodTy); - for (RecordAttr record : nestedPodTy.getRecords()) { - recordChain.push_back(record.getName()); - Value recordValue = - rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, rewriter); - genWrite(loc, nestedPod, record.getName(), recordValue, rewriter); - recordChain.pop_back(); + ReadPodOp fieldRead = peelUnifiableCasts(op.getOperand(0)).getDefiningOp(); + if (!fieldRead) { + return failure(); } - return nestedPod; + + SmallVector splitLeafArrays; + if (!resolveReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, deferredPodArrays, op.getLoc(), + rewriter, splitLeafArrays + )) { + return failure(); + } + + rewriter.replaceOp(op, splitLeafArrays); + eraseDeadDeferredFieldReadChain(fieldRead, rewriter); + return success(); } +}; - auto it = leafValues.find(RecordChain(recordChain)); - assert(it != leafValues.end() && "missing flattened POD leaf value"); - return it->second; -} +/// Update virtual POD leaf storage in response to `pod.write` without materializing the aggregate. +class ResolveVirtualPodWriteOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; -/// Rewrite a write to a pod-typed struct member into writes to the corresponding scalar leaves. -class SplitPodInMemberWriteOp : public SplitAggregateInMemberRefOp< - SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain> { public: - using SplitAggregateInMemberRefOp< - SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain>::SplitAggregateInMemberRefOp; - - static bool legal(MemberWriteOp op) { return !containsSplittablePodType(op.getVal().getType()); } + ResolveVirtualPodWriteOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} - static void *genHeader(MemberWriteOp, ConversionPatternRewriter &) { return nullptr; } + LogicalResult matchAndRewrite( + WritePodOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + auto it = virtualPods.find(adaptor.getPodRef()); + if (it == virtualPods.end()) { + return failure(); + } - static void forId( - Location loc, void *&, RecordChain id, MemberInfo newMember, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter - ) { - Value scalarRead = genReadAlongPath(loc, adaptor.getVal(), id, rewriter); - rewriter.create( - loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead + Type recordType = + llvm::cast(op.getPodRefType()).getRecordMap().lookup(op.getRecordName()); + assert(recordType && "record must exist in POD type"); + updateVirtualPodRecordLeafValues( + op.getLoc(), op.getRecordNameAttr(), recordType, adaptor.getValue(), virtualPods, rewriter, + it->second ); + rewriter.eraseOp(op); + return success(); } }; -/// Rewrite a read from a pod-typed struct member into reads from the corresponding scalar leaves. -class SplitPodInMemberReadOp - : public SplitAggregateInMemberRefOp< - SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState, RecordChain> { +/// Resolve reads from a virtual POD placeholder without materializing the whole aggregate. +/// +/// This pattern answers `pod.read` directly from virtual leaf storage, rebuilding nested POD +/// subrecords on demand and casting scalar leaves back to the precise record type when needed. +class ResolveVirtualPodReadOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + public: - using SplitAggregateInMemberRefOp< - SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState, - RecordChain>::SplitAggregateInMemberRefOp; + ResolveVirtualPodReadOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} - static bool legal(MemberReadOp op) { - return !containsSplittablePodType(op.getResult().getType()); - } + LogicalResult matchAndRewrite( + ReadPodOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (hasEarlierWriteInBlock(op) || findNearestForwardableWriteInBlock(op)) { + return failure(); + } - static RebuildPodReadState genHeader(MemberReadOp op, ConversionPatternRewriter &rewriter) { - RebuildPodReadState state; - state.pod = rewriter.create(op.getLoc(), llvm::cast(op.getType())); - rewriter.replaceAllUsesWith(op, state.pod); - return state; - } + const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(adaptor.getPodRef(), virtualPods); + if (!leafValues) { + return failure(); + } - static void forId( - Location loc, RebuildPodReadState &state, RecordChain id, MemberInfo newMember, - OpAdaptor adaptor, ConversionPatternRewriter &rewriter - ) { - Value scalarRead = rewriter.create( - loc, newMember.second, adaptor.getComponent(), newMember.first - ); - state.leafValues[id] = scalarRead; - } + SmallVector prefix {op.getRecordNameAttr()}; + Type recordType = + llvm::cast(op.getPodRefType()).getRecordMap().lookup(op.getRecordName()); + assert(recordType && "record must exist in POD type"); + + if (PodType nestedPodTy = llvm::dyn_cast(recordType)) { + VirtualPodLeafMap nestedLeafValues; + SmallVector nestedRecordChain; + forEachPodLeaf(nestedPodTy, nestedRecordChain, [&](RecordChain id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + nestedLeafValues[id] = leafValues->at(RecordChain(fullChain)); + }); + Value virtualPod = + createVirtualPodPlaceholder(rewriter, op.getLoc(), nestedPodTy, nestedLeafValues); + virtualPods[virtualPod] = std::move(nestedLeafValues); + rewriter.replaceOp(op, virtualPod); + return success(); + } - static void finalize( - MemberReadOp op, RebuildPodReadState &state, OpAdaptor, ConversionPatternRewriter &rewriter - ) { - auto podTy = llvm::cast(op.getType()); - SmallVector recordChain; - for (RecordAttr record : podTy.getRecords()) { - recordChain.push_back(record.getName()); - Value recordValue = rebuildFlattenedPodRecord( - op.getLoc(), record.getType(), recordChain, state.leafValues, rewriter - ); - genWrite(op.getLoc(), state.pod, record.getName(), recordValue, rewriter); - recordChain.pop_back(); + if (splittablePodArray(recordType)) { + return failure(); } + + rewriter.replaceOp( + op, castValueToTypeIfNeeded( + rewriter, op.getLoc(), leafValues->at(RecordChain(prefix)), recordType + ) + ); + return success(); } }; /// Special handling to split pods in struct member refs and function signatures and desugar /// initializations on pod.new into pod writes. static LogicalResult -step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { +step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { MLIRContext *ctx = modOp.getContext(); + VirtualPodValueMap virtualPods; + DeferredPodArrayLeafMap deferredPodArrays; RewritePatternSet patterns(ctx); - patterns.add< - // clang-format off - SplitInitFromNewPodOp, - SplitPodInFuncDefOp, - SplitPodInReturnOp, - SplitPodInCallOp - // clang-format on - >(ctx); - - patterns.add< - // clang-format off - SplitPodInMemberWriteOp, - SplitPodInMemberReadOp - // clang-format on - >(ctx, symTables, memberRepMap); + patterns.add(ctx); + patterns.add(ctx, virtualPods); + patterns.add(ctx, virtualPods); + patterns.add( + ctx, symTables, memberRepMap, virtualPods + ); + patterns.add(ctx, virtualPods, deferredPodArrays); + patterns.add(ctx, virtualPods, deferredPodArrays); + patterns.add(ctx, virtualPods); ConversionTarget target(*ctx); baseTargetSetup(target); target.addDynamicallyLegalOp(SplitInitFromNewPodOp::legal); + target.addDynamicallyLegalOp(SplitPodElementCreateArrayOp::legal); target.addDynamicallyLegalOp(SplitPodInFuncDefOp::legal); target.addDynamicallyLegalOp(SplitPodInReturnOp::legal); target.addDynamicallyLegalOp(SplitPodInCallOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberWriteOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberReadOp::legal); - - LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other pod ops\n";); - return applyFullConversion(modOp, target, std::move(patterns)); -} - -/// Return whether the given read/write access targets the same POD record. -inline static bool isSamePodRecord(ReadPodOp readOp, Value podRef, StringAttr recordName) { - return readOp.getPodRef() == podRef && readOp.getRecordNameAttr() == recordName; -} - -/// Return whether the given read/write access targets the same POD record. -inline static bool isSamePodRecord(WritePodOp writeOp, Value podRef, StringAttr recordName) { - return writeOp.getPodRef() == podRef && writeOp.getRecordNameAttr() == recordName; -} - -/// Return whether `op` contains a nested write to `podRef.recordName`. -static bool hasNestedWriteToRecord(Operation &op, Value podRef, StringAttr recordName) { - return walkContainsMatch(op, [&](WritePodOp writeOp) { - return writeOp.getOperation() != &op && isSamePodRecord(writeOp, podRef, recordName); + target.addDynamicallyLegalOp([&virtualPods](WritePodOp op) { + return !lookupVirtualPodLeafMap(op.getPodRef(), virtualPods); }); -} - -/// Return whether `op` contains any read from `podRef.recordName`. -static bool hasReadFromRecord(Operation &op, Value podRef, StringAttr recordName) { - return walkContainsMatch(op, [&](ReadPodOp readOp) { - return isSamePodRecord(readOp, podRef, recordName); + target.addDynamicallyLegalOp([&virtualPods](ReadArrayOp op) { + return !ResolvePodReadBackedArrayReadOp::canResolve(op, virtualPods); }); -} - -/// Return whether `op` or any nested operation uses `value` as an operand. -static bool hasValueUse(Operation &op, Value value) { - return walkContainsMatch(op, [&value](Operation *nestedOp) { - return llvm::is_contained(nestedOp->getOperands(), value); + target.addDynamicallyLegalOp( + [&virtualPods](UnrealizedConversionCastOp op) { + return !ResolveDeferredSplitPodArrayCastOp::canResolve(op, virtualPods); + } + ); + target.addDynamicallyLegalOp([&virtualPods](ReadPodOp op) { + return !canResolveVirtualPodRead(op, virtualPods); }); -} -/// Return whether the read is preceded by a write to the same pod record within its block. -static bool hasEarlierWriteInBlock(ReadPodOp readOp) { - Value podRef = readOp.getPodRef(); - StringAttr recordName = readOp.getRecordNameAttr(); + LLVM_DEBUG(llvm::dbgs() << "Begin step 3: update/split other pod ops\n";); + if (failed(applyFullConversion(modOp, target, std::move(patterns)))) { + return failure(); + } - for (Operation &op : *readOp->getBlock()) { - if (&op == readOp.getOperation()) { - return false; + OpBuilder builder(ctx); + for (auto &[podValue, leafValues] : virtualPods) { + if (podValue.use_empty()) { + continue; + } + if (auto newPod = llvm::dyn_cast(podValue.getDefiningOp())) { + builder.setInsertionPointAfter(findVirtualPodMaterializationAnchor(newPod, leafValues)); + materializeVirtualPod(builder, newPod, leafValues); } + } - if (auto writeOp = dyn_cast(&op)) { - if (isSamePodRecord(writeOp, podRef, recordName)) { - return true; + bool erasedDeadPlaceholderOps = false; + do { + SmallVector deadPlaceholderOps; + modOp->walk([&](Operation *op) { + if (auto readOp = llvm::dyn_cast(op)) { + if (readOp.getResult().use_empty()) { + deadPlaceholderOps.push_back(op); + } + return; } - continue; + + if (auto castOp = llvm::dyn_cast(op)) { + if (llvm::all_of(castOp.getResults(), [](Value result) { return result.use_empty(); })) { + deadPlaceholderOps.push_back(op); + } + } + }); + for (Operation *op : deadPlaceholderOps) { + op->erase(); } + erasedDeadPlaceholderOps = !deadPlaceholderOps.empty(); + } while (erasedDeadPlaceholderOps); - if (hasNestedWriteToRecord(op, podRef, recordName)) { - return true; + SmallVector deadOps; + modOp->walk([&](Operation *op) { + if (op != modOp.getOperation() && isOpTriviallyDead(op)) { + deadOps.push_back(op); } + }); + for (Operation *op : deadOps) { + op->erase(); } - return false; + return success(); } /// Return whether `value` is defined within `ancestor` or one of its nested regions. @@ -866,6 +3527,20 @@ static WritePodOp findPrecedingWriteForIfRead(ReadPodOp readOp) { return replacement; } +/// Replace a read with the value from the nearest preceding same-record write in the block. +class FoldReadAfterWriteInBlockPattern final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReadPodOp readOp, PatternRewriter &rewriter) const override { + if (WritePodOp writeOp = findNearestForwardableWriteInBlock(readOp)) { + rewriter.replaceOp(readOp, writeOp.getValue()); + return success(); + } + return failure(); + } +}; + /// Replace a branch-local read with a value available in the parent block. class ReplaceIfReadPattern final : public OpRewritePattern { public: @@ -887,7 +3562,7 @@ class ReplaceIfReadPattern final : public OpRewritePattern { rewriter.setInsertionPoint(ifOp); rewriter.replaceOp( - readOp, genRead(readOp.getLoc(), readOp.getPodRef(), readOp.getRecordNameAttr(), rewriter) + readOp, genRead(rewriter, readOp.getLoc(), readOp.getPodRef(), readOp.getRecordNameAttr()) .getResult() ); return success(); @@ -1092,8 +3767,8 @@ moveBranchWithoutLiftedWrites(Block *srcBlock, Block &destBlock, ArrayRef slots, - bool isThenBlock, OpBuilder &builder + OpBuilder &bldr, Location loc, Block &block, ValueRange priorYieldValues, + ArrayRef slots, bool isThenBlock ) { SmallVector yieldValues = llvm::to_vector(priorYieldValues); llvm::append_range(yieldValues, llvm::map_range(slots, [isThenBlock](const IfWriteSlot &slot) { @@ -1101,8 +3776,8 @@ static void appendYield( return writeOp ? writeOp.getValue() : slot.incomingValue; })); - builder.setInsertionPointToEnd(&block); - builder.create(loc, yieldValues); + bldr.setInsertionPointToEnd(&block); + bldr.create(loc, yieldValues); } /// One POD record whose value is carried across an SCF loop boundary as an SSA scalar. @@ -1124,7 +3799,7 @@ struct LoopPodSlot { /// Return the tracked loop slot for `podRef.recordName`, or null if not found. static LoopPodSlot * lookupLoopSlot(SmallVectorImpl &slots, Value podRef, StringAttr recordName) { - auto it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { + auto *it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { return slot.matches(podRef, recordName); }); return it == slots.end() ? nullptr : &*it; @@ -1132,7 +3807,7 @@ lookupLoopSlot(SmallVectorImpl &slots, Value podRef, StringAttr rec /// Return whether a loop slot is tracked for `podRef.recordName`. static bool hasLoopSlot(ArrayRef slots, Value podRef, StringAttr recordName) { - auto it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { + const auto *it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { return slot.matches(podRef, recordName); }); return it != slots.end(); @@ -1263,7 +3938,7 @@ class LiftPodWritesFromIfBlocksPattern final : public OpRewritePattern resultTypes = llvm::to_vector(ifOp.getResultTypes()); @@ -1290,15 +3965,15 @@ class LiftPodWritesFromIfBlocksPattern final : public OpRewritePattern newInitArgs = llvm::to_vector(forOp.getInitArgs()); rewriter.setInsertionPoint(forOp); for (const LoopPodSlot &slot : slots) { - newInitArgs.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult()); + newInitArgs.push_back(genRead(rewriter, loc, slot.podRef, slot.recordName).getResult()); } auto newFor = rewriter.create( @@ -1383,7 +4058,7 @@ class LiftPodAccessesFromForLoopPattern final : public OpRewritePattern newResultTypes = llvm::to_vector(whileOp.getResultTypes()); rewriter.setInsertionPoint(whileOp); for (const LoopPodSlot &slot : slots) { - newInits.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult()); + newInits.push_back(genRead(rewriter, loc, slot.podRef, slot.recordName).getResult()); newResultTypes.push_back(slot.type); } @@ -1526,8 +4201,8 @@ class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(constrain::EmitEqualityOp op, PatternRewriter &rewriter) const override { + PodType podTy = splittablePod(op.getLhs().getType()); + if (!podTy) { + return failure(); + } + + SmallVector recordChain; + forEachPodLeaf(podTy, recordChain, [&rewriter, &op](const RecordChain &id, Type) { + Value lhsLeaf = genReadAlongPath(rewriter, op.getLoc(), op.getLhs(), id); + Value rhsLeaf = genReadAlongPath(rewriter, op.getLoc(), op.getRhs(), id); + rewriter.create(op.getLoc(), lhsLeaf, rhsLeaf); + }); + rewriter.eraseOp(op); + return success(); + } +}; + /// Apply a greedy rewrite/fold pass over the module body using the provided patterns. static LogicalResult applyGreedily(ModuleOp modOp, RewritePatternSet &&patterns, bool *changed = nullptr) { @@ -1547,15 +4249,14 @@ applyGreedily(ModuleOp modOp, RewritePatternSet &&patterns, bool *changed = null /// Repeatedly lift pod accesses out of supported SCF regions so SROA + mem2reg can eliminate the /// remaining POD storage. -static LogicalResult step3(ModuleOp modOp) { +static LogicalResult step4(ModuleOp modOp) { RewritePatternSet patterns(modOp.getContext()); patterns.add< - ReplaceIfReadPattern, LiftPodWritesFromIfBlocksPattern, LiftPodAccessesFromForLoopPattern, - LiftPodAccessesFromWhileLoopPattern, FoldIfCarriedPodReadAfterWritePattern>( - patterns.getContext() - ); + FoldReadAfterWriteInBlockPattern, ReplaceIfReadPattern, LiftPodWritesFromIfBlocksPattern, + LiftPodAccessesFromForLoopPattern, LiftPodAccessesFromWhileLoopPattern, + FoldIfCarriedPodReadAfterWritePattern, SplitPodInEmitEqualityPattern>(patterns.getContext()); - LLVM_DEBUG(llvm::dbgs() << "Begin step 3: refactor pod ops within SCF regions\n";); + LLVM_DEBUG(llvm::dbgs() << "Begin step 4: refactor pod ops within SCF regions\n";); return applyGreedily(modOp, std::move(patterns)); } @@ -1636,13 +4337,21 @@ class PassImpl : public llzk::pod::impl::PodToScalarPassBase { llvm::dbgs() << "After step 2:\n"; module.dump(); }); + + if (failed(step3(module, symTables, memberRepMap))) { + return signalPassFailure(); + } + LLVM_DEBUG({ + llvm::dbgs() << "After step 3:\n"; + module.dump(); + }); } - if (failed(step3(module))) { + if (failed(step4(module))) { return signalPassFailure(); } LLVM_DEBUG({ - llvm::dbgs() << "After step 3:\n"; + llvm::dbgs() << "After step 4:\n"; module.dump(); }); @@ -1656,6 +4365,11 @@ class PassImpl : public llzk::pod::impl::PodToScalarPassBase { // Cleanup allocations made dead by memory promotion and other dead SSA values. OpPassManager cleanupPM(ModuleOp::getOperationName()); + cleanupPM.addPass(createRemoveUnusedDiscardableAllocationsPass( + RemoveUnusedDiscardableAllocationsPassOptions { + .allocatorOpName = CreateArrayOp::getOperationName().str() + } + )); cleanupPM.addPass(createRemoveUnusedDiscardableAllocationsPass( RemoveUnusedDiscardableAllocationsPassOptions { .allocatorOpName = NewPodOp::getOperationName().str() diff --git a/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp b/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp index 6d63de085..1b332db16 100644 --- a/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp +++ b/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp @@ -969,18 +969,6 @@ LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) { namespace Step1B_InstantiateFunctions { -/// Flatten nested array instantiations by appending any dimensions contributed by the converted -/// element type onto the outer array. This allows wildcard element types to resolve to -/// higher-rank arrays even though LLZK array element types cannot themselves be arrays. -static ArrayType flattenInstantiatedArrayType(ArrayType inputTy, Type convertedElemTy) { - SmallVector mergedDims(inputTy.getDimensionSizes()); - while (ArrayType nestedArrTy = llvm::dyn_cast(convertedElemTy)) { - llvm::append_range(mergedDims, nestedArrTy.getDimensionSizes()); - convertedElemTy = nestedArrTy.getElementType(); - } - return ArrayType::get(convertedElemTy, mergedDims); -} - /// TypeConverter for function instantiation that replaces TypeVarType and symbolic /// ArrayType/StructType parameters with their concrete values determined by unification. class FuncInstTypeConverter : public TypeConverter { @@ -1020,7 +1008,7 @@ class FuncInstTypeConverter : public TypeConverter { if (!changed && newElemTy == inputTy.getElementType()) { return inputTy; } - return flattenInstantiatedArrayType( + return flattenArrayElementType( inputTy.cloneWith(inputTy.getElementType(), updated), newElemTy ); }); diff --git a/lib/Transforms/LLZKTransformationPassPipelines.cpp b/lib/Transforms/LLZKTransformationPassPipelines.cpp index 1c87cfcb6..442bc9f69 100644 --- a/lib/Transforms/LLZKTransformationPassPipelines.cpp +++ b/lib/Transforms/LLZKTransformationPassPipelines.cpp @@ -49,14 +49,14 @@ void buildFullStructInliningPipelineImpl( } pm.addPass(polymorphic::createFlatteningPass(flattening)); - // Run array-to-scalar first because it can split arrays within a pod - // but pod-to-scalar cannot split pods within an array. - if (arrayToScalar) { - pm.addPass(array::createArrayToScalarPass()); - } + // Run pod-to-scalar first because it is able to split `pod.type` used as array element type + // (into parallel arrays) so it should be able to fully remove all `pod.type` usages. if (podToScalar) { pm.addPass(pod::createPodToScalarPass()); } + if (arrayToScalar) { + pm.addPass(array::createArrayToScalarPass()); + } // Canonicalize to remove known-condition `scf.if` regions so struct inlining // can link "@compute" calls to struct members. pm.addPass(mlir::createCanonicalizerPass()); diff --git a/lib/Util/TypeHelper.cpp b/lib/Util/TypeHelper.cpp index 2b7c5fa5d..f511c86e0 100644 --- a/lib/Util/TypeHelper.cpp +++ b/lib/Util/TypeHelper.cpp @@ -564,6 +564,15 @@ bool hasAffineMapAttr(Type type) { bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); } +ArrayType flattenArrayElementType(ArrayType outerArrTy, Type elementType) { + SmallVector mergedDims(outerArrTy.getDimensionSizes()); + while (ArrayType nestedArrTy = llvm::dyn_cast(elementType)) { + llvm::append_range(mergedDims, nestedArrTy.getDimensionSizes()); + elementType = nestedArrTy.getElementType(); + } + return ArrayType::get(elementType, mergedDims); +} + uint64_t computeEmitEqCardinality(Type type) { struct Impl : LLZKTypeSwitch { uint64_t caseBool(IntegerType) { return 1; } diff --git a/test/Transforms/PodToScalar/array_extract_insert.llzk b/test/Transforms/PodToScalar/array_extract_insert.llzk new file mode 100644 index 000000000..227dc8ef8 --- /dev/null +++ b/test/Transforms/PodToScalar/array_extract_insert.llzk @@ -0,0 +1,30 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@lhs: index, @rhs: !felt.type]> +!PairRow = !array.type<2 x !Pair> +!PairMatrix = !array.type<2,2 x !Pair> +module attributes {llzk.lang} { + function.def @extract_then_insert(%src: !PairMatrix) -> !PairMatrix { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %row = array.extract %src[%c0] : !PairMatrix + %dst = array.new : !PairMatrix + array.insert %dst[%c1] = %row : !PairMatrix, !PairRow + function.return %dst : !PairMatrix + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @extract_then_insert +// CHECK-SAME: (%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x !felt.type>) +// CHECK-SAME: -> (!array.type<2,2 x index>, !array.type<2,2 x !felt.type>) { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_0]]{{\[}}%[[VAL_2]]] : <2,2 x index> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_1]]{{\[}}%[[VAL_2]]] : <2,2 x !felt.type> +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = array.new : <2,2 x index> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.new : <2,2 x !felt.type> +// CHECK-NEXT: array.insert %[[VAL_6]]{{\[}}%[[VAL_3]]] = %[[VAL_4]] : <2,2 x index>, <2 x index> +// CHECK-NEXT: array.insert %[[VAL_7]]{{\[}}%[[VAL_3]]] = %[[VAL_5]] : <2,2 x !felt.type>, <2 x !felt.type> +// CHECK-NEXT: function.return %[[VAL_6]], %[[VAL_7]] : !array.type<2,2 x index>, !array.type<2,2 x !felt.type> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk b/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk new file mode 100644 index 000000000..93415f4ba --- /dev/null +++ b/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk @@ -0,0 +1,78 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Elem = !pod.type<[@vals: !array.type<3 x index>, @tag: index]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @pass_through(%arg: !ElemArray) -> !ElemArray { + function.return %arg : !ElemArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @pass_through( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,3 x index>, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x index> +// CHECK-SAME: ) -> (!array.type<2,3 x index>, !array.type<2 x index>) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2,3 x index>, !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Elem = !pod.type<[@vals: !array.type<3 x index>]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @pack( + %a0: index, %a1: index, %a2: index, %b0: index, %b1: index, %b2: index + ) -> !ElemArray { + %vals0 = array.new %a0, %a1, %a2 : !array.type<3 x index> + %vals1 = array.new %b0, %b1, %b2 : !array.type<3 x index> + %lhs = pod.new : !Elem + pod.write %lhs[@vals] = %vals0 : !Elem, !array.type<3 x index> + %rhs = pod.new : !Elem + pod.write %rhs[@vals] = %vals1 : !Elem, !array.type<3 x index> + %arr = array.new %lhs, %rhs : !ElemArray + function.return %arr : !ElemArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @pack( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]]: index, %[[VAL_3:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_4:[0-9a-zA-Z_\.]+]]: index, %[[VAL_5:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> !array.type<2,3 x index> { +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = array.new %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : <3 x index> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.new %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : <3 x index> +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = array.new : <2,3 x index> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.insert %[[VAL_8]]{{\[}}%[[VAL_9]]] = %[[VAL_6]] : <2,3 x index>, <3 x index> +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.insert %[[VAL_8]]{{\[}}%[[VAL_10]]] = %[[VAL_7]] : <2,3 x index>, <3 x index> +// CHECK-NEXT: function.return %[[VAL_8]] : !array.type<2,3 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Elem = !pod.type<[@vals: !array.type<3 x index>, @tag: index]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @update_first(%arr: !ElemArray, %newVals: !array.type<3 x index>) -> !ElemArray { + %c0 = arith.constant 0 : index + %elem = array.read %arr[%c0] : !ElemArray, !Elem + pod.write %elem[@vals] = %newVals : !Elem, !array.type<3 x index> + array.write %arr[%c0] = %elem : !ElemArray, !Elem + function.return %arr : !ElemArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @update_first( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,3 x index>, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]]: !array.type<3 x index> +// CHECK-SAME: ) -> (!array.type<2,3 x index>, !array.type<2 x index>) { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_0]]{{\[}}%[[VAL_3]]] : <2,3 x index> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_1]]{{\[}}%[[VAL_3]]] : <2 x index>, index +// CHECK-NEXT: array.insert %[[VAL_0]]{{\[}}%[[VAL_3]]] = %[[VAL_2]] : <2,3 x index>, <3 x index> +// CHECK-NEXT: array.write %[[VAL_1]]{{\[}}%[[VAL_3]]] = %[[VAL_5]] : <2 x index>, index +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2,3 x index>, !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/array_length.llzk b/test/Transforms/PodToScalar/array_length.llzk new file mode 100644 index 000000000..0257ee0c0 --- /dev/null +++ b/test/Transforms/PodToScalar/array_length.llzk @@ -0,0 +1,129 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @len_multi_leaf(%arr: !array.type<2 x !Pair>, %dim: index) -> index { + %len = array.len %arr, %dim : !array.type<2 x !Pair> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_multi_leaf(%[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[DIM]] : <2 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +// Tests: preserve the `array.len` semantics so a dynamic dimension selection +// cannot observe dimensions that did not exist before pod flattening. +!Elem = !pod.type<[@vals: !array.type<3 x index>]> +module attributes {llzk.lang} { + function.def @len_leaf_array_dynamic_dim(%arr: !array.type<2 x !Elem>, %dim: index) -> index { + %len = array.len %arr, %dim : !array.type<2 x !Elem> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_leaf_array_dynamic_dim(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<2,3 x index>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[SHAPE:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[SHAPE]], %[[DIM]] : <2 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + function.def @len_empty_leaf_static_create(%dim: index) -> index { + %arr = array.new : !array.type<4 x !pod.type<[]>> + %len = array.len %arr, %dim : !array.type<4 x !pod.type<[]>> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_static_create(%[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <4 x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <4 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#map = affine_map<()[s0] -> (s0)> +module attributes {llzk.lang} { + function.def @len_empty_leaf_array(%n: index, %dim: index) -> index { + %arr = array.new{()[%n]} : !array.type<#map x !pod.type<[]>> + %len = array.len %arr, %dim : !array.type<#map x !pod.type<[]>> + function.return %len : index + } +} +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_array(%[[N:[0-9a-zA-Z_\.]+]]: index, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new{()[%[[N]]]} : <#[[$MAP]] x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#[[$MAP]] x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + function.def @len_empty_leaf_static_arg(%arr: !array.type<4 x !pod.type<[]>>, %dim: index) -> index { + %len = array.len %arr, %dim : !array.type<4 x !pod.type<[]>> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_static_arg(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<4 x index>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <4 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#map_arg = affine_map<()[s0] -> (s0)> +module attributes {llzk.lang} { + function.def @len_empty_leaf_affine_arg( + %arr: !array.type<#map_arg x !pod.type<[]>>, %dim: index + ) -> index { + %len = array.len %arr, %dim : !array.type<#map_arg x !pod.type<[]>> + function.return %len : index + } +} +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_affine_arg(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#[[$MAP]] x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#map_call = affine_map<()[s0] -> (s0)> +module attributes {llzk.lang} { + function.def @len_empty_leaf_affine_sink( + %arr: !array.type<#map_call x !pod.type<[]>>, %dim: index + ) -> index { + %len = array.len %arr, %dim : !array.type<#map_call x !pod.type<[]>> + function.return %len : index + } + + function.def @len_empty_leaf_affine_call( + %arr: !array.type<#map_call x !pod.type<[]>>, %dim: index + ) -> index { + %len = function.call @len_empty_leaf_affine_sink(%arr, %dim) + : (!array.type<#map_call x !pod.type<[]>>, index) -> index + function.return %len : index + } +} +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_affine_sink(%[[ARR0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, %[[DIM0:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN0:[0-9a-zA-Z_\.]+]] = array.len %[[ARR0]], %[[DIM0]] : <#[[$MAP]] x index> +// CHECK-NEXT: function.return %[[LEN0]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @len_empty_leaf_affine_call(%[[ARR1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, %[[DIM1:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN1:[0-9a-zA-Z_\.]+]] = function.call @len_empty_leaf_affine_sink(%[[ARR1]], %[[DIM1]]) : (!array.type<#[[$MAP]] x index>, index) -> index +// CHECK-NEXT: function.return %[[LEN1]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk b/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk new file mode 100644 index 000000000..9297ccf06 --- /dev/null +++ b/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk @@ -0,0 +1,85 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +#inner = affine_map<()[s0] -> (s0 + 1)> +!Elem = !pod.type<[@vals: !array.type<#inner x index>]> +!ElemArray = !array.type<#outer x !Elem> +module attributes {llzk.lang} { + function.def @uninitialized(%n: index) -> !ElemArray { + // The affine map initializer here is for the `#outer` map in `!ElemArray`, not the `#inner` map in `!Elem`. + // The `#inner` map can only be initialized via an `array.new` that is then written to the pod which is not + // present in this example. Thus, when `pod-to-scalar` combines the arrays to create a single, 2-D array, + // the second dimension size must be `?` because no affine instantiation exists to provide as an affine + // instantiation. However, the return value itself is unchanged so a `unifiable_cast` is used to convert. + %arr = array.new{()[%n]} : !ElemArray + function.return %arr : !ElemArray + } +} +// CHECK: #[[$M1:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK: #[[$M2:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @uninitialized(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> !array.type<#[[$M1]],#[[$M2]] x index> { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <#[[$M1]],? x index> +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_1]] : (!array.type<#[[$M1]],? x index>) -> !array.type<#[[$M1]],#[[$M2]] x index> +// CHECK-NEXT: function.return %[[VAL_2]] : !array.type<#[[$M1]],#[[$M2]] x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#inner = affine_map<()[s0] -> (s0 + 1)> +!Leaf = !array.type<#inner x index> +!Elem = !pod.type<[@vals: !Leaf]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @materialized(%n: index) -> !ElemArray { + %vals0 = array.new{()[%n]} : !Leaf + %vals1 = array.new{()[%n]} : !Leaf + %lhs = pod.new()[%n] : !Elem + pod.write %lhs[@vals] = %vals0 : !Elem, !Leaf + %rhs = pod.new()[%n] : !Elem + pod.write %rhs[@vals] = %vals1 : !Elem, !Leaf + %arr = array.new %lhs, %rhs : !ElemArray + function.return %arr : !ElemArray + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @materialized(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> !array.type<2,#[[$M]] x index> { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <#[[$M]] x index> +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <#[[$M]] x index> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <2,#[[$M]] x index> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.insert %[[VAL_3]]{{\[}}%[[VAL_4]]] = %[[VAL_1]] : <2,#[[$M]] x index>, <#[[$M]] x index> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.insert %[[VAL_3]]{{\[}}%[[VAL_5]]] = %[[VAL_2]] : <2,#[[$M]] x index>, <#[[$M]] x index> +// CHECK-NEXT: function.return %[[VAL_3]] : !array.type<2,#[[$M]] x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#inner = affine_map<()[s0] -> (s0 + 1)> +!Leaf = !array.type<#inner x index> +!Elem = !pod.type<[@vals: !Leaf]> +!ElemArray = !array.type<2 x !Elem> +!Outer = !pod.type<[@items: !ElemArray]> +module attributes {llzk.lang} { + function.def @sink(%arr: !ElemArray) { + function.return + } + + function.def @source(%p: !Outer) { + %items = pod.read %p[@items] : !Outer, !ElemArray + function.call @sink(%items) : (!ElemArray) -> () + function.return + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @sink(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,#[[$M]] x index>) { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: function.def @source(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2,#[[$M]] x index>) { +// CHECK-NEXT: function.call @sink(%[[VAL_1]]) : (!array.type<2,#[[$M]] x index>) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/array_new_affine_leaf_array_conflict.llzk b/test/Transforms/PodToScalar/array_new_affine_leaf_array_conflict.llzk new file mode 100644 index 000000000..da5fa9a2c --- /dev/null +++ b/test/Transforms/PodToScalar/array_new_affine_leaf_array_conflict.llzk @@ -0,0 +1,20 @@ +// RUN: llzk-opt %s -split-input-file -llzk-pod-to-scalar -verify-diagnostics + +#inner = affine_map<()[s0] -> (s0 + 1)> +!InnerArr = !array.type<#inner x index> +!Elem = !pod.type<[@vals: !InnerArr]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @conflict(%m: index, %n: index) -> !ElemArray { + %vals0 = array.new{()[%m]} : !InnerArr + %vals1 = array.new{()[%n]} : !InnerArr + %lhs = pod.new()[%m] : !Elem + pod.write %lhs[@vals] = %vals0 : !Elem, !InnerArr + %rhs = pod.new()[%n] : !Elem + pod.write %rhs[@vals] = %vals1 : !Elem, !InnerArr + // expected-error@+2 {{'array.new' op with POD elements having conflicting affine map instantiations cannot be promoted to higher dimensional array}} + // expected-error@+1 {{failed to legalize operation 'array.new'}} + %arr = array.new %lhs, %rhs : !ElemArray + function.return %arr : !ElemArray + } +} diff --git a/test/Transforms/PodToScalar/array_read_from_pod_field.llzk b/test/Transforms/PodToScalar/array_read_from_pod_field.llzk new file mode 100644 index 000000000..e5062bebd --- /dev/null +++ b/test/Transforms/PodToScalar/array_read_from_pod_field.llzk @@ -0,0 +1,22 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<2 x !Item>]> +module attributes {llzk.lang} { + function.def @read_item(%p: !Outer, %i: index) -> index { + %items = pod.read %p[@items] : !Outer, !array.type<2 x !Item> + %item = array.read %items[%i] : !array.type<2 x !Item>, !Item + %x = pod.read %item[@x] : !Item, index + function.return %x : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @read_item(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[IDX:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[UNUSED0:[0-9a-zA-Z_\.]+]] = array.read %[[ARR]]{{\[}}%[[C0]]] : <2 x index>, index +// CHECK-NEXT: %[[C1:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[UNUSED1:[0-9a-zA-Z_\.]+]] = array.read %[[ARR]]{{\[}}%[[C1]]] : <2 x index>, index +// CHECK-NEXT: %[[SEL:[0-9a-zA-Z_\.]+]] = array.read %[[ARR]]{{\[}}%[[IDX]]] : <2 x index>, index +// CHECK-NEXT: function.return %[[SEL]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/array_read_from_unwritten_pod_field.llzk b/test/Transforms/PodToScalar/array_read_from_unwritten_pod_field.llzk new file mode 100644 index 000000000..33be2b4af --- /dev/null +++ b/test/Transforms/PodToScalar/array_read_from_unwritten_pod_field.llzk @@ -0,0 +1,22 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + function.def @read_unwritten_dynamic(%n: index, %i: index) -> index { + %p = pod.new()[%n] : !Outer + %items = pod.read %p[@items] : !Outer, !array.type<#outer x !Item> + %item = array.read %items[%i] : !array.type<#outer x !Item>, !Item + %x = pod.read %item[@x] : !Item, index + function.return %x : index + } +} +// CHECK: #[[$ATTR_0:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @read_unwritten_dynamic(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#[[$ATTR_0]] x index> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_2]]{{\[}}%[[VAL_1]]] : <#[[$ATTR_0]] x index>, index +// CHECK-NEXT: function.return %[[VAL_3]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/array_write_direct_pod_dynamic_nested_array.llzk b/test/Transforms/PodToScalar/array_write_direct_pod_dynamic_nested_array.llzk new file mode 100644 index 000000000..f8ca47829 --- /dev/null +++ b/test/Transforms/PodToScalar/array_write_direct_pod_dynamic_nested_array.llzk @@ -0,0 +1,26 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#inner = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#inner x !Item>, @tag: index]> +!OuterArray = !array.type<2 x !Outer> +module attributes {llzk.lang} { + function.def @write_elem(%arr: !OuterArray, %elem: !Outer, %i: index) -> !OuterArray { + array.write %arr[%i] = %elem : !OuterArray, !Outer + function.return %arr : !OuterArray + } +} +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @write_elem( +// CHECK-SAME: %[[ARR_ITEMS:[0-9a-zA-Z_\.]+]]: !array.type<2,#[[$MAP]] x index>, +// CHECK-SAME: %[[ARR_TAG:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[ELEM_ITEMS:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, +// CHECK-SAME: %[[ELEM_TAG:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[IDX:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> (!array.type<2,#[[$MAP]] x index>, !array.type<2 x index>) { +// CHECK-NEXT: array.insert %[[ARR_ITEMS]]{{\[}}%[[IDX]]] = %[[ELEM_ITEMS]] : <2,#[[$MAP]] x index>, <#[[$MAP]] x index> +// CHECK-NEXT: array.write %[[ARR_TAG]]{{\[}}%[[IDX]]] = %[[ELEM_TAG]] : <2 x index>, index +// CHECK-NEXT: function.return %[[ARR_ITEMS]], %[[ARR_TAG]] : !array.type<2,#[[$MAP]] x index>, !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/bool_quantifiers.llzk b/test/Transforms/PodToScalar/bool_quantifiers.llzk new file mode 100644 index 000000000..126f38c5c --- /dev/null +++ b/test/Transforms/PodToScalar/bool_quantifiers.llzk @@ -0,0 +1,79 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@x: index, @y: index]> +module attributes {llzk.lang} { + function.def @forall_arg(%arr: !array.type<2 x !Pair>, %limit: index) -> i1 attributes {function.allow_non_native_field_ops} { + %all = bool.forall %elt in %arr : !array.type<2 x !Pair> { + %x = pod.read %elt[@x] : !Pair, index + %ok = arith.cmpi slt, %x, %limit : index + bool.yield %ok + } + function.return %all : i1 + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @forall_arg( +// CHECK-SAME: %[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[LIMIT:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> i1 attributes {function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[LB:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[UB:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[LB]] : <2 x index> +// CHECK-NEXT: %[[STEP:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[INIT:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: %[[RES:[0-9a-zA-Z_\.]+]] = scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ACC:[0-9a-zA-Z_\.]+]] = %[[INIT]]) -> (i1) { +// CHECK-DAG: %[[X:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_X]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[Y:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_Y]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[OK:[0-9a-zA-Z_\.]+]] = arith.cmpi slt, %[[X]], %[[LIMIT]] : index +// CHECK-NEXT: %[[COMBINED:[0-9a-zA-Z_\.]+]] = bool.and %[[ACC]], %[[OK]] : i1, i1 +// CHECK-NEXT: scf.yield %[[COMBINED]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: function.return %[[RES]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Pair = !pod.type<[@x: index, @y: index]> +module attributes {llzk.lang} { + function.def @exists_array_new(%x0: index, %y0: index, %x1: index, %y1: index) -> i1 attributes {function.allow_non_native_field_ops} { + %p0 = pod.new { @x = %x0, @y = %y0 } : !Pair + %p1 = pod.new { @x = %x1, @y = %y1 } : !Pair + %arr = array.new %p0, %p1 : !array.type<2 x !Pair> + %any = bool.exists %elt in %arr : !array.type<2 x !Pair> { + %y = pod.read %elt[@y] : !Pair, index + %ok = arith.cmpi eq, %y, %y1 : index + bool.yield %ok + } + function.return %any : i1 + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @exists_array_new( +// CHECK-SAME: %[[X0:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[Y0:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[X1:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[Y1:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> i1 attributes {function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[ARR_X:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: %[[C0X:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.write %[[ARR_X]][%[[C0X]]] = %[[X0]] : <2 x index>, index +// CHECK-NEXT: %[[C1X:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[ARR_X]][%[[C1X]]] = %[[X1]] : <2 x index>, index +// CHECK-NEXT: %[[ARR_Y:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: %[[C0Y:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.write %[[ARR_Y]][%[[C0Y]]] = %[[Y0]] : <2 x index>, index +// CHECK-NEXT: %[[C1Y:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[ARR_Y]][%[[C1Y]]] = %[[Y1]] : <2 x index>, index +// CHECK-NEXT: %[[LB:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[UB:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[LB]] : <2 x index> +// CHECK-NEXT: %[[STEP:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[INIT:[0-9a-zA-Z_\.]+]] = arith.constant false +// CHECK-NEXT: %[[RES:[0-9a-zA-Z_\.]+]] = scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ACC:[0-9a-zA-Z_\.]+]] = %[[INIT]]) -> (i1) { +// CHECK-DAG: %[[X:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_X]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[Y:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_Y]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[OK:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[Y]], %[[Y1]] : index +// CHECK-NEXT: %[[COMBINED:[0-9a-zA-Z_\.]+]] = bool.or %[[ACC]], %[[OK]] : i1, i1 +// CHECK-NEXT: scf.yield %[[COMBINED]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: function.return %[[RES]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/call_with_pod_read_array_of_pod.llzk b/test/Transforms/PodToScalar/call_with_pod_read_array_of_pod.llzk new file mode 100644 index 000000000..e1655de75 --- /dev/null +++ b/test/Transforms/PodToScalar/call_with_pod_read_array_of_pod.llzk @@ -0,0 +1,30 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Pair = !pod.type<[@x: index, @y: !felt.type]> +!Holder = !pod.type<[@pairs: !array.type<#outer x !Pair>]> +module attributes {llzk.lang} { + function.def @sink(%arr: !array.type<#outer x !Pair>) -> index { + %c0 = arith.constant 0 : index + %len = array.len %arr, %c0 : !array.type<#outer x !Pair> + function.return %len : index + } + + function.def @main(%holder: !Holder) -> index { + %pairs = pod.read %holder[@pairs] : !Holder, !array.type<#outer x !Pair> + %len = function.call @sink(%pairs) : (!array.type<#outer x !Pair>) -> index + function.return %len : index + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @sink(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> index { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = array.len %[[VAL_0]], %[[VAL_2]] : <#[[$M]] x index> +// CHECK-NEXT: function.return %[[VAL_3]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @main(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> index { +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = function.call @sink(%[[VAL_4]], %[[VAL_5]]) : (!array.type<#[[$M]] x index>, !array.type<#[[$M]] x !felt.type>) -> index +// CHECK-NEXT: function.return %[[VAL_6]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/circom_decomp_prod.llzk b/test/Transforms/PodToScalar/circom_decomp_prod.llzk new file mode 100644 index 000000000..7a1e01dd2 --- /dev/null +++ b/test/Transforms/PodToScalar/circom_decomp_prod.llzk @@ -0,0 +1,545 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s 2>&1 | FileCheck --enable-var-scope %s + +!F = !felt.type<"bn128"> +module attributes {llzk.lang = "circom", llzk.main = !struct.type<@DecomposeProduct_1>} { + struct.def @Num2Bits_0 { + struct.member @out : !array.type<8 x !F> {llzk.pub} + function.def @compute(%arg0: !F) -> !struct.type<@Num2Bits_0> attributes {function.allow_non_native_field_ops} { + %self = struct.new : <@Num2Bits_0> + %nondet = llzk.nondet : !array.type<8 x !F> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %felt_const_1 = felt.const 1 : !F + %felt_const_0_0 = felt.const 0 : !F + %0:3 = scf.while (%arg1 = %felt_const_1, %arg2 = %felt_const_0, %arg3 = %felt_const_0_0) : (!F, !F, !F) -> (!F, !F, !F) { + %felt_const_8_1 = felt.const 8 : !F + %1 = bool.cmp lt(%arg3, %felt_const_8_1) : !F, !F + scf.condition(%1) %arg1, %arg2, %arg3 : !F, !F, !F + } do { + ^bb0(%arg1: !F, %arg2: !F, %arg3: !F): + %1 = felt.shr %arg0, %arg3 : !F, !F + %felt_const_1_1 = felt.const 1 : !F + %2 = felt.bit_and %1, %felt_const_1_1 : !F, !F + %3 = cast.toindex %arg3 : !F + array.write %nondet[%3] = %2 : <8 x !F>, !F + %4 = cast.toindex %arg3 : !F + %5 = array.read %nondet[%4] : <8 x !F>, !F + %6 = felt.mul %5, %felt_const_1 : !F, !F + %7 = felt.add %arg2, %6 : !F, !F + %8 = felt.add %arg1, %arg1 : !F, !F + %felt_const_1_2 = felt.const 1 : !F + %9 = felt.add %arg3, %felt_const_1_2 : !F, !F + scf.yield %8, %7, %9 : !F, !F, !F + } + struct.writem %self[@out] = %nondet : <@Num2Bits_0>, !array.type<8 x !F> + function.return %self : !struct.type<@Num2Bits_0> + } + function.def @constrain(%arg0: !struct.type<@Num2Bits_0>, %arg1: !F) attributes {function.allow_non_native_field_ops} { + %0 = struct.readm %arg0[@out] : <@Num2Bits_0>, !array.type<8 x !F> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %felt_const_1 = felt.const 1 : !F + %felt_const_0_0 = felt.const 0 : !F + %1:3 = scf.while (%arg2 = %felt_const_0, %arg3 = %felt_const_1, %arg4 = %felt_const_0_0) : (!F, !F, !F) -> (!F, !F, !F) { + %felt_const_8_1 = felt.const 8 : !F + %2 = bool.cmp lt(%arg4, %felt_const_8_1) : !F, !F + scf.condition(%2) %arg2, %arg3, %arg4 : !F, !F, !F + } do { + ^bb0(%arg2: !F, %arg3: !F, %arg4: !F): + %2 = cast.toindex %arg4 : !F + %3 = array.read %0[%2] : <8 x !F>, !F + %4 = cast.toindex %arg4 : !F + %5 = array.read %0[%4] : <8 x !F>, !F + %felt_const_1_1 = felt.const 1 : !F + %6 = felt.sub %5, %felt_const_1_1 : !F, !F + %7 = felt.mul %3, %6 : !F, !F + %felt_const_0_2 = felt.const 0 : !F + constrain.eq %7, %felt_const_0_2 : !F, !F + %8 = cast.toindex %arg4 : !F + %9 = array.read %0[%8] : <8 x !F>, !F + %10 = felt.mul %9, %felt_const_1 : !F, !F + %11 = felt.add %arg2, %10 : !F, !F + %12 = felt.add %arg3, %arg3 : !F, !F + %felt_const_1_3 = felt.const 1 : !F + %13 = felt.add %arg4, %felt_const_1_3 : !F, !F + scf.yield %11, %12, %13 : !F, !F, !F + } + constrain.eq %1#0, %arg1 : !F, !F + function.return + } + } + struct.def @DecomposeProduct_1 { + struct.member @high : !array.type<8 x !F> {llzk.pub} + struct.member @low : !array.type<8 x !F> {llzk.pub} + struct.member @u16s : !array.type<8 x !F> + struct.member @bits_low : !array.type<8 x !struct.type<@Num2Bits_0>> + struct.member @bits_low$inputs : !array.type<8 x !pod.type<[@in: !F]>> + struct.member @bits_high : !array.type<8 x !struct.type<@Num2Bits_0>> + struct.member @bits_high$inputs : !array.type<8 x !pod.type<[@in: !F]>> + function.def @compute(%arg0: !array.type<8 x !F>, %arg1: !array.type<8 x !F>) -> !struct.type<@DecomposeProduct_1> attributes {function.allow_non_native_field_ops} { + %self = struct.new : <@DecomposeProduct_1> + %nondet = llzk.nondet : !array.type<8 x !F> + %nondet_0 = llzk.nondet : !array.type<8 x !F> + %nondet_1 = llzk.nondet : !array.type<8 x !F> + %array = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.for %arg2 = %c0 to %c8 step %c1 { + %1 = array.read %array[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %c1_16 = arith.constant 1 : index + pod.write %1[@count] = %c1_16 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + array.write %array[%arg2] = %1 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } + %array_2 = array.new : <8 x !pod.type<[@in: !F]>> + %array_3 = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> + %c8_4 = arith.constant 8 : index + %c0_5 = arith.constant 0 : index + %c1_6 = arith.constant 1 : index + scf.for %arg2 = %c0_5 to %c8_4 step %c1_6 { + %1 = array.read %array_3[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %c1_16 = arith.constant 1 : index + pod.write %1[@count] = %c1_16 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + array.write %array_3[%arg2] = %1 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } + %array_7 = array.new : <8 x !pod.type<[@in: !F]>> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %0:3 = scf.while (%arg2 = %array_7, %arg3 = %array_2, %arg4 = %felt_const_0) : (!array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F) -> (!array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F) { + %felt_const_8_16 = felt.const 8 : !F + %1 = bool.cmp lt(%arg4, %felt_const_8_16) : !F, !F + scf.condition(%1) %arg2, %arg3, %arg4 : !array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F + } do { + ^bb0(%arg2: !array.type<8 x !pod.type<[@in: !F]>>, %arg3: !array.type<8 x !pod.type<[@in: !F]>>, %arg4: !F): + %1 = cast.toindex %arg4 : !F + %2 = array.read %arg0[%1] : <8 x !F>, !F + %3 = cast.toindex %arg4 : !F + %4 = array.read %arg1[%3] : <8 x !F>, !F + %5 = felt.mul %2, %4 : !F, !F + %6 = cast.toindex %arg4 : !F + array.write %nondet[%6] = %5 : <8 x !F>, !F + %7 = cast.toindex %arg4 : !F + %8 = array.read %nondet[%7] : <8 x !F>, !F + %felt_const_256 = felt.const 256 : !F + %9 = felt.umod %8, %felt_const_256 : !F, !F + %10 = cast.toindex %arg4 : !F + array.write %nondet_0[%10] = %9 : <8 x !F>, !F + %11 = cast.toindex %arg4 : !F + %12 = array.read %nondet_0[%11] : <8 x !F>, !F + %13 = cast.toindex %arg4 : !F + %14 = array.read %arg3[%13] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + pod.write %14[@in] = %12 : <[@in: !F]>, !F + %15 = cast.toindex %arg4 : !F + array.write %arg3[%15] = %14 : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %16 = cast.toindex %arg4 : !F + %17 = array.read %array[%16] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %18 = pod.read %17[@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c1_16 = arith.constant 1 : index + %19 = arith.subi %18, %c1_16 : index + pod.write %17[@count] = %19 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c0_17 = arith.constant 0 : index + %20 = arith.cmpi eq, %19, %c0_17 : index + scf.if %20 { + %37 = pod.read %14[@in] : <[@in: !F]>, !F + %38 = function.call @Num2Bits_0::@compute(%37) : (!F) -> !struct.type<@Num2Bits_0> + pod.write %17[@comp] = %38 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + %39 = cast.toindex %arg4 : !F + array.write %array[%39] = %17 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } else { + } + %21 = cast.toindex %arg4 : !F + %22 = array.read %nondet[%21] : <8 x !F>, !F + %felt_const_256_18 = felt.const 256 : !F + %23 = felt.uintdiv %22, %felt_const_256_18 : !F, !F + %felt_const_256_19 = felt.const 256 : !F + %24 = felt.umod %23, %felt_const_256_19 : !F, !F + %25 = cast.toindex %arg4 : !F + array.write %nondet_1[%25] = %24 : <8 x !F>, !F + %26 = cast.toindex %arg4 : !F + %27 = array.read %nondet_1[%26] : <8 x !F>, !F + %28 = cast.toindex %arg4 : !F + %29 = array.read %arg2[%28] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + pod.write %29[@in] = %27 : <[@in: !F]>, !F + %30 = cast.toindex %arg4 : !F + array.write %arg2[%30] = %29 : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %31 = cast.toindex %arg4 : !F + %32 = array.read %array_3[%31] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %33 = pod.read %32[@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c1_20 = arith.constant 1 : index + %34 = arith.subi %33, %c1_20 : index + pod.write %32[@count] = %34 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c0_21 = arith.constant 0 : index + %35 = arith.cmpi eq, %34, %c0_21 : index + scf.if %35 { + %37 = pod.read %29[@in] : <[@in: !F]>, !F + %38 = function.call @Num2Bits_0::@compute(%37) : (!F) -> !struct.type<@Num2Bits_0> + pod.write %32[@comp] = %38 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + %39 = cast.toindex %arg4 : !F + array.write %array_3[%39] = %32 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } else { + } + %felt_const_1 = felt.const 1 : !F + %36 = felt.add %arg4, %felt_const_1 : !F, !F + scf.yield %arg2, %arg3, %36 : !array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F + } + struct.writem %self[@bits_high$inputs] = %0#0 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %array_8 = array.new : <8 x !struct.type<@Num2Bits_0>> + %c8_9 = arith.constant 8 : index + %c0_10 = arith.constant 0 : index + %c1_11 = arith.constant 1 : index + scf.for %arg2 = %c0_10 to %c8_9 step %c1_11 { + %1 = array.read %array_3[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %2 = pod.read %1[@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + array.write %array_8[%arg2] = %2 : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + } + struct.writem %self[@bits_high] = %array_8 : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + struct.writem %self[@bits_low$inputs] = %0#1 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %array_12 = array.new : <8 x !struct.type<@Num2Bits_0>> + %c8_13 = arith.constant 8 : index + %c0_14 = arith.constant 0 : index + %c1_15 = arith.constant 1 : index + scf.for %arg2 = %c0_14 to %c8_13 step %c1_15 { + %1 = array.read %array[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %2 = pod.read %1[@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + array.write %array_12[%arg2] = %2 : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + } + struct.writem %self[@bits_low] = %array_12 : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + struct.writem %self[@u16s] = %nondet : <@DecomposeProduct_1>, !array.type<8 x !F> + struct.writem %self[@low] = %nondet_0 : <@DecomposeProduct_1>, !array.type<8 x !F> + struct.writem %self[@high] = %nondet_1 : <@DecomposeProduct_1>, !array.type<8 x !F> + function.return %self : !struct.type<@DecomposeProduct_1> + } + function.def @constrain(%arg0: !struct.type<@DecomposeProduct_1>, %arg1: !array.type<8 x !F>, %arg2: !array.type<8 x !F>) attributes {function.allow_non_native_field_ops} { + %0 = struct.readm %arg0[@high] : <@DecomposeProduct_1>, !array.type<8 x !F> + %1 = struct.readm %arg0[@low] : <@DecomposeProduct_1>, !array.type<8 x !F> + %2 = struct.readm %arg0[@u16s] : <@DecomposeProduct_1>, !array.type<8 x !F> + %3 = struct.readm %arg0[@bits_low] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + %4 = struct.readm %arg0[@bits_low$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %5 = struct.readm %arg0[@bits_high] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + %6 = struct.readm %arg0[@bits_high$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %7 = scf.while (%arg3 = %felt_const_0) : (!F) -> !F { + %felt_const_8_3 = felt.const 8 : !F + %8 = bool.cmp lt(%arg3, %felt_const_8_3) : !F, !F + scf.condition(%8) %arg3 : !F + } do { + ^bb0(%arg3: !F): + %8 = cast.toindex %arg3 : !F + %9 = array.read %arg1[%8] : <8 x !F>, !F + %10 = cast.toindex %arg3 : !F + %11 = array.read %arg2[%10] : <8 x !F>, !F + %12 = felt.mul %9, %11 : !F, !F + %13 = cast.toindex %arg3 : !F + %14 = array.read %2[%13] : <8 x !F>, !F + constrain.eq %14, %12 : !F, !F + %15 = cast.toindex %arg3 : !F + %16 = array.read %1[%15] : <8 x !F>, !F + %17 = cast.toindex %arg3 : !F + %18 = array.read %4[%17] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %19 = pod.read %18[@in] : <[@in: !F]>, !F + constrain.eq %19, %16 : !F, !F + %20 = cast.toindex %arg3 : !F + %21 = array.read %0[%20] : <8 x !F>, !F + %22 = cast.toindex %arg3 : !F + %23 = array.read %6[%22] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %24 = pod.read %23[@in] : <[@in: !F]>, !F + constrain.eq %24, %21 : !F, !F + %25 = cast.toindex %arg3 : !F + %26 = array.read %2[%25] : <8 x !F>, !F + %27 = cast.toindex %arg3 : !F + %28 = array.read %1[%27] : <8 x !F>, !F + %felt_const_256 = felt.const 256 : !F + %29 = cast.toindex %arg3 : !F + %30 = array.read %0[%29] : <8 x !F>, !F + %31 = felt.mul %felt_const_256, %30 : !F, !F + %32 = felt.add %28, %31 : !F, !F + constrain.eq %26, %32 : !F, !F + %felt_const_1 = felt.const 1 : !F + %33 = felt.add %arg3, %felt_const_1 : !F, !F + scf.yield %33 : !F + } + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.for %arg3 = %c0 to %c8 step %c1 { + %8 = array.read %5[%arg3] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + %9 = array.read %6[%arg3] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %10 = pod.read %9[@in] : <[@in: !F]>, !F + function.call @Num2Bits_0::@constrain(%8, %10) : (!struct.type<@Num2Bits_0>, !F) -> () + } + %c8_0 = arith.constant 8 : index + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + scf.for %arg3 = %c0_1 to %c8_0 step %c1_2 { + %8 = array.read %3[%arg3] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + %9 = array.read %4[%arg3] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %10 = pod.read %9[@in] : <[@in: !F]>, !F + function.call @Num2Bits_0::@constrain(%8, %10) : (!struct.type<@Num2Bits_0>, !F) -> () + } + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang = "circom", llzk.main = !struct.type<@DecomposeProduct_1>} { +// CHECK-NEXT: struct.def @Num2Bits_0 { +// CHECK-NEXT: struct.member @out : !array.type<8 x !felt.type<"bn128">> {llzk.pub} +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) -> !struct.type<@Num2Bits_0> attributes {function.allow_non_native_field_ops, function.allow_witness} { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Num2Bits_0> +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = scf.while (%[[VAL_5:[0-9a-zA-Z_\.]+]] = %[[VAL_3]]) : (!felt.type<"bn128">) -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_5]], %[[VAL_6]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_7]]) %[[VAL_5]] : !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_8:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.shr %[[VAL_0]], %[[VAL_8]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = felt.bit_and %[[VAL_9]], %[[VAL_10]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_8]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_2]]{{\[}}%[[VAL_12]]] = %[[VAL_11]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_8]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_2]]{{\[}}%[[VAL_13]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_8]], %[[VAL_15]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_16]] : !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_1]][@out] = %[[VAL_2]] : <@Num2Bits_0>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_17:[0-9a-zA-Z_\.]+]]: !struct.type<@Num2Bits_0>, %[[VAL_18:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) attributes {function.allow_constraint, function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_19:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@out] : <@Num2Bits_0>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]]:2 = scf.while (%[[VAL_24:[0-9a-zA-Z_\.]+]] = %[[VAL_20]], %[[VAL_25:[0-9a-zA-Z_\.]+]] = %[[VAL_22]]) : (!felt.type<"bn128">, !felt.type<"bn128">) -> (!felt.type<"bn128">, !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_26:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_27:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_25]], %[[VAL_26]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_27]]) %[[VAL_24]], %[[VAL_25]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_28:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_29:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_30:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_29]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_31:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_30]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_32:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_29]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_33:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_32]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_34:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_35:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_33]], %[[VAL_34]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_36:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_31]], %[[VAL_35]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_37:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_36]], %[[VAL_37]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_38:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_29]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_39:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_38]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_40:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_39]], %[[VAL_21]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_41:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_28]], %[[VAL_40]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_42:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_43:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_29]], %[[VAL_42]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_41]], %[[VAL_43]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: constrain.eq %[[VAL_23]]#0, %[[VAL_18]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: struct.def @DecomposeProduct_1 { +// CHECK-NEXT: struct.member @high : !array.type<8 x !felt.type<"bn128">> {llzk.pub} +// CHECK-NEXT: struct.member @low : !array.type<8 x !felt.type<"bn128">> {llzk.pub} +// CHECK-NEXT: struct.member @u16s : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.member @bits_low : !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.member @bits_low$inputs_in : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.member @bits_high : !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.member @bits_high$inputs_in : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: function.def @compute(%[[VAL_44:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_45:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) -> !struct.type<@DecomposeProduct_1> attributes {function.allow_non_native_field_ops, function.allow_witness} { +// CHECK-NEXT: %[[VAL_46:[0-9a-zA-Z_\.]+]] = struct.new : <@DecomposeProduct_1> +// CHECK-NEXT: %[[VAL_47:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_48:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_49:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_50:[0-9a-zA-Z_\.]+]] = array.new : <8 x index> +// CHECK-NEXT: %[[VAL_51:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_52:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_53:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_54:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_55:[0-9a-zA-Z_\.]+]] = %[[VAL_53]] to %[[VAL_52]] step %[[VAL_54]] { +// CHECK-NEXT: %[[VAL_56:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_50]]{{\[}}%[[VAL_55]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_57:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_51]]{{\[}}%[[VAL_55]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_58:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[VAL_50]]{{\[}}%[[VAL_55]]] = %[[VAL_58]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_51]]{{\[}}%[[VAL_55]]] = %[[VAL_57]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_59:[0-9a-zA-Z_\.]+]] = array.new : <8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_60:[0-9a-zA-Z_\.]+]] = array.new : <8 x index> +// CHECK-NEXT: %[[VAL_61:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_62:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_63:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_64:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_65:[0-9a-zA-Z_\.]+]] = %[[VAL_63]] to %[[VAL_62]] step %[[VAL_64]] { +// CHECK-NEXT: %[[VAL_66:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_65]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_67:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_65]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_68:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[VAL_60]]{{\[}}%[[VAL_65]]] = %[[VAL_68]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_61]]{{\[}}%[[VAL_65]]] = %[[VAL_67]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_69:[0-9a-zA-Z_\.]+]] = array.new : <8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_70:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_71:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_72:[0-9a-zA-Z_\.]+]] = %[[VAL_69]], %[[VAL_73:[0-9a-zA-Z_\.]+]] = %[[VAL_59]], %[[VAL_74:[0-9a-zA-Z_\.]+]] = %[[VAL_70]]) : (!array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128">) -> (!array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_75:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_76:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_74]], %[[VAL_75]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_76]]) %[[VAL_72]], %[[VAL_73]], %[[VAL_74]] : !array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_77:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_78:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_79:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_80:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_81:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_44]]{{\[}}%[[VAL_80]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_82:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_83:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_45]]{{\[}}%[[VAL_82]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_84:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_81]], %[[VAL_83]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_85:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_47]]{{\[}}%[[VAL_85]]] = %[[VAL_84]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_86:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_87:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_47]]{{\[}}%[[VAL_86]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_88:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_89:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_87]], %[[VAL_88]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_90:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_48]]{{\[}}%[[VAL_90]]] = %[[VAL_89]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_91:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_92:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_48]]{{\[}}%[[VAL_91]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_93:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_94:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_78]]{{\[}}%[[VAL_93]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_95:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_78]]{{\[}}%[[VAL_95]]] = %[[VAL_92]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_96:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_97:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_50]]{{\[}}%[[VAL_96]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_98:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_51]]{{\[}}%[[VAL_96]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_99:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_100:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_97]], %[[VAL_99]] : index +// CHECK-NEXT: %[[VAL_101:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_102:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_100]], %[[VAL_101]] : index +// CHECK-NEXT: scf.if %[[VAL_102]] { +// CHECK-NEXT: %[[VAL_103:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_92]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_104:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_50]]{{\[}}%[[VAL_104]]] = %[[VAL_100]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_51]]{{\[}}%[[VAL_104]]] = %[[VAL_103]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_105:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_106:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_47]]{{\[}}%[[VAL_105]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_107:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_108:[0-9a-zA-Z_\.]+]] = felt.uintdiv %[[VAL_106]], %[[VAL_107]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_109:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_110:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_108]], %[[VAL_109]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_111:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_49]]{{\[}}%[[VAL_111]]] = %[[VAL_110]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_112:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_113:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_49]]{{\[}}%[[VAL_112]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_114:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_115:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_77]]{{\[}}%[[VAL_114]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_116:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_77]]{{\[}}%[[VAL_116]]] = %[[VAL_113]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_117:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_118:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_117]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_119:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_117]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_120:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_121:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_118]], %[[VAL_120]] : index +// CHECK-NEXT: %[[VAL_122:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_123:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_121]], %[[VAL_122]] : index +// CHECK-NEXT: scf.if %[[VAL_123]] { +// CHECK-NEXT: %[[VAL_124:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_113]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_125:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_60]]{{\[}}%[[VAL_125]]] = %[[VAL_121]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_61]]{{\[}}%[[VAL_125]]] = %[[VAL_124]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_126:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_127:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_79]], %[[VAL_126]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_77]], %[[VAL_78]], %[[VAL_127]] : !array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_high$inputs_in] = %[[VAL_71]]#0 : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_128:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_129:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_130:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_131:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_132:[0-9a-zA-Z_\.]+]] = %[[VAL_130]] to %[[VAL_129]] step %[[VAL_131]] { +// CHECK-NEXT: %[[VAL_133:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_132]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_134:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_132]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: array.write %[[VAL_128]]{{\[}}%[[VAL_132]]] = %[[VAL_134]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_high] = %[[VAL_128]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_low$inputs_in] = %[[VAL_71]]#1 : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_135:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_136:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_137:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_138:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_139:[0-9a-zA-Z_\.]+]] = %[[VAL_137]] to %[[VAL_136]] step %[[VAL_138]] { +// CHECK-NEXT: %[[VAL_140:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_50]]{{\[}}%[[VAL_139]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_141:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_51]]{{\[}}%[[VAL_139]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: array.write %[[VAL_135]]{{\[}}%[[VAL_139]]] = %[[VAL_141]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_low] = %[[VAL_135]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.writem %[[VAL_46]][@u16s] = %[[VAL_47]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.writem %[[VAL_46]][@low] = %[[VAL_48]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.writem %[[VAL_46]][@high] = %[[VAL_49]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: function.return %[[VAL_46]] : !struct.type<@DecomposeProduct_1> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_142:[0-9a-zA-Z_\.]+]]: !struct.type<@DecomposeProduct_1>, %[[VAL_143:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_144:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) attributes {function.allow_constraint, function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_145:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@high] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_146:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@low] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_147:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@u16s] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_148:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_low] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_149:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_low$inputs_in] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_150:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_high] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_151:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_high$inputs_in] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_152:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_153:[0-9a-zA-Z_\.]+]] = scf.while (%[[VAL_154:[0-9a-zA-Z_\.]+]] = %[[VAL_152]]) : (!felt.type<"bn128">) -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_155:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_156:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_154]], %[[VAL_155]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_156]]) %[[VAL_154]] : !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_157:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_158:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_159:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_143]]{{\[}}%[[VAL_158]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_160:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_161:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_144]]{{\[}}%[[VAL_160]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_162:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_159]], %[[VAL_161]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_163:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_164:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_147]]{{\[}}%[[VAL_163]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_164]], %[[VAL_162]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_165:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_166:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_146]]{{\[}}%[[VAL_165]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_167:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_168:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_149]]{{\[}}%[[VAL_167]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_168]], %[[VAL_166]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_169:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_170:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_145]]{{\[}}%[[VAL_169]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_171:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_172:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_151]]{{\[}}%[[VAL_171]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_172]], %[[VAL_170]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_173:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_174:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_147]]{{\[}}%[[VAL_173]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_175:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_176:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_146]]{{\[}}%[[VAL_175]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_177:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_178:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_179:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_145]]{{\[}}%[[VAL_178]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_180:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_177]], %[[VAL_179]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_181:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_176]], %[[VAL_180]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_174]], %[[VAL_181]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_182:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_183:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_157]], %[[VAL_182]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_183]] : !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_184:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_185:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_186:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_187:[0-9a-zA-Z_\.]+]] = %[[VAL_185]] to %[[VAL_184]] step %[[VAL_186]] { +// CHECK-NEXT: %[[VAL_188:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_150]]{{\[}}%[[VAL_187]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_189:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_151]]{{\[}}%[[VAL_187]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_188]], %[[VAL_189]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_190:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_191:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_192:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_193:[0-9a-zA-Z_\.]+]] = %[[VAL_191]] to %[[VAL_190]] step %[[VAL_192]] { +// CHECK-NEXT: %[[VAL_194:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_148]]{{\[}}%[[VAL_193]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_195:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_149]]{{\[}}%[[VAL_193]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_194]], %[[VAL_195]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () +// CHECK-NEXT: } +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/constrain_array_of_pod.llzk b/test/Transforms/PodToScalar/constrain_array_of_pod.llzk new file mode 100644 index 000000000..44557bf67 --- /dev/null +++ b/test/Transforms/PodToScalar/constrain_array_of_pod.llzk @@ -0,0 +1,91 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @eq_array_pod( + %lhs: !array.type<2 x !Pair>, %rhs: !array.type<2 x !Pair> + ) attributes {function.allow_constraint} { + constrain.eq %lhs, %rhs : !array.type<2 x !Pair> + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @eq_array_pod( +// CHECK-SAME: %[[LHS_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[LHS_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, +// CHECK-SAME: %[[RHS_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[RHS_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type> +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: constrain.eq %[[LHS_X]], %[[RHS_X]] : !array.type<2 x index>, !array.type<2 x index> +// CHECK-NEXT: constrain.eq %[[LHS_Y]], %[[RHS_Y]] : !array.type<2 x !felt.type>, !array.type<2 x !felt.type> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @contains_pod( + %arr: !array.type<2 x !Pair>, %elem: !Pair + ) attributes {function.allow_constraint} { + constrain.in %arr, %elem : !array.type<2 x !Pair>, !Pair + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @contains_pod( +// CHECK-SAME: %[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, +// CHECK-SAME: %[[ELEM_X:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[ELEM_Y:[0-9a-zA-Z_\.]+]]: !felt.type +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[TRUE:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: %[[IDX:[0-9a-zA-Z_\.]+]] = llzk.nondet : index +// CHECK-NEXT: %[[DIM0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[DIM0]] : <2 x index> +// CHECK-NEXT: %[[GE0:[0-9a-zA-Z_\.]+]] = arith.cmpi sge, %[[IDX]], %[[C0]] : index +// CHECK-NEXT: constrain.eq %[[GE0]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[LT_LEN:[0-9a-zA-Z_\.]+]] = arith.cmpi slt, %[[IDX]], %[[LEN]] : index +// CHECK-NEXT: constrain.eq %[[LT_LEN]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[SEL_X:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_X]]{{\[}}%[[IDX]]] : <2 x index>, index +// CHECK-NEXT: constrain.eq %[[SEL_X]], %[[ELEM_X]] : index, index +// CHECK-NEXT: %[[SEL_Y:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_Y]]{{\[}}%[[IDX]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: constrain.eq %[[SEL_Y]], %[[ELEM_Y]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @contains_subarray( + %arr: !array.type<2,2 x !Pair>, %sub: !array.type<2 x !Pair> + ) attributes {function.allow_constraint} { + constrain.in %arr, %sub : !array.type<2,2 x !Pair>, !array.type<2 x !Pair> + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @contains_subarray( +// CHECK-SAME: %[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>, +// CHECK-SAME: %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x !felt.type>, +// CHECK-SAME: %[[SUB_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[SUB_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type> +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[TRUE:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: %[[IDX:[0-9a-zA-Z_\.]+]] = llzk.nondet : index +// CHECK-NEXT: %[[DIM0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[DIM0]] : <2,2 x index> +// CHECK-NEXT: %[[GE0:[0-9a-zA-Z_\.]+]] = arith.cmpi sge, %[[IDX]], %[[C0]] : index +// CHECK-NEXT: constrain.eq %[[GE0]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[LT_LEN:[0-9a-zA-Z_\.]+]] = arith.cmpi slt, %[[IDX]], %[[LEN]] : index +// CHECK-NEXT: constrain.eq %[[LT_LEN]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[SEL_X:[0-9a-zA-Z_\.]+]] = array.extract %[[ARR_X]]{{\[}}%[[IDX]]] : <2,2 x index> +// CHECK-NEXT: constrain.eq %[[SEL_X]], %[[SUB_X]] : !array.type<2 x index>, !array.type<2 x index> +// CHECK-NEXT: %[[SEL_Y:[0-9a-zA-Z_\.]+]] = array.extract %[[ARR_Y]]{{\[}}%[[IDX]]] : <2,2 x !felt.type> +// CHECK-NEXT: constrain.eq %[[SEL_Y]], %[[SUB_Y]] : !array.type<2 x !felt.type>, !array.type<2 x !felt.type> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/function_calls_with_pod.llzk b/test/Transforms/PodToScalar/function_calls_with_pod.llzk index 6a24ab974..549f771ff 100644 --- a/test/Transforms/PodToScalar/function_calls_with_pod.llzk +++ b/test/Transforms/PodToScalar/function_calls_with_pod.llzk @@ -159,11 +159,11 @@ module attributes {llzk.lang} { } } // CHECK-LABEL: module attributes {llzk.lang} { -// CHECK-NEXT: function.def @id_array_of_pod(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x !pod.type<[@x: index]>>) -> !array.type<2 x !pod.type<[@x: index]>> { -// CHECK-NEXT: function.return %[[VAL_0]] : !array.type<2 x !pod.type<[@x: index]>> +// CHECK-NEXT: function.def @id_array_of_pod(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>) -> !array.type<2 x index> { +// CHECK-NEXT: function.return %[[VAL_0]] : !array.type<2 x index> // CHECK-NEXT: } -// CHECK-NEXT: function.def @main(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x !pod.type<[@x: index]>>) -> !array.type<2 x !pod.type<[@x: index]>> { -// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @id_array_of_pod(%[[VAL_1]]) : (!array.type<2 x !pod.type<[@x: index]>>) -> !array.type<2 x !pod.type<[@x: index]>> -// CHECK-NEXT: function.return %[[VAL_2]] : !array.type<2 x !pod.type<[@x: index]>> +// CHECK-NEXT: function.def @main(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>) -> !array.type<2 x index> { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @id_array_of_pod(%[[VAL_1]]) : (!array.type<2 x index>) -> !array.type<2 x index> +// CHECK-NEXT: function.return %[[VAL_2]] : !array.type<2 x index> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/function_result_attrs_array_of_pod.llzk b/test/Transforms/PodToScalar/function_result_attrs_array_of_pod.llzk new file mode 100644 index 000000000..b3e69ddef --- /dev/null +++ b/test/Transforms/PodToScalar/function_result_attrs_array_of_pod.llzk @@ -0,0 +1,16 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@lhs: index, @rhs: !felt.type]> +!PodArray = !array.type<2 x !Pair> +module attributes {llzk.lang} { + function.def @named_array_result(%arg: !PodArray) -> (!PodArray {function.res_name = "out"}) { + function.return %arg : !PodArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @named_array_result +// CHECK-SAME: (%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>) +// CHECK-SAME: -> (!array.type<2 x index> {function.res_name = "out.lhs"}, !array.type<2 x !felt.type> {function.res_name = "out.rhs"}) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2 x index>, !array.type<2 x !felt.type> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/function_signatures_with_nested_pod_array.llzk b/test/Transforms/PodToScalar/function_signatures_with_nested_pod_array.llzk new file mode 100644 index 000000000..8e140cff2 --- /dev/null +++ b/test/Transforms/PodToScalar/function_signatures_with_nested_pod_array.llzk @@ -0,0 +1,28 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Inner = !pod.type<[@x: index]> +!Outer = !pod.type<[@a: !array.type<2 x !Inner>, @b: index]> +module attributes {llzk.lang} { + function.def @id( + %arg: !Outer {function.arg_name = "arg"} + ) -> (!Outer {function.res_name = "out"}) { + function.return %arg : !Outer + } + + function.def @main(%arg: !Outer) -> !Outer { + %res = function.call @id(%arg) : (!Outer) -> !Outer + function.return %res : !Outer + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @id( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index> {function.arg_name = "arg.a.x"}, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: index {function.arg_name = "arg.b"} +// CHECK-SAME: ) -> (!array.type<2 x index> {function.res_name = "out.a.x"}, index {function.res_name = "out.b"}) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2 x index>, index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @main(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: index) -> (!array.type<2 x index>, index) { +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]]:2 = function.call @id(%[[VAL_2]], %[[VAL_3]]) : (!array.type<2 x index>, index) -> (!array.type<2 x index>, index) +// CHECK-NEXT: function.return %[[VAL_4]]#0, %[[VAL_4]]#1 : !array.type<2 x index>, index +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/member_read_column_affine_offset_array_of_pod.llzk b/test/Transforms/PodToScalar/member_read_column_affine_offset_array_of_pod.llzk new file mode 100644 index 000000000..214182f72 --- /dev/null +++ b/test/Transforms/PodToScalar/member_read_column_affine_offset_array_of_pod.llzk @@ -0,0 +1,47 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#map = affine_map<()[s0] -> (s0 - 1)> +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !array.type<2 x !pod.type<[@x: !felt.type, @y: !felt.type]>> {column} + + function.def @compute(%idx: index) -> !struct.type<@S> { + %self = struct.new : !struct.type<@S> + function.return %self : !struct.type<@S> + } + + function.def @constrain(%self: !struct.type<@S>, %idx: index) { + %arr = struct.readm %self[@m] {()[%idx]} : + !struct.type<@S>, !array.type<2 x !pod.type<[@x: !felt.type, @y: !felt.type]>> + {tableOffset = #map} + %c0 = arith.constant 0 : index + %elt = array.read %arr[%c0] : + !array.type<2 x !pod.type<[@x: !felt.type, @y: !felt.type]>>, + !pod.type<[@x: !felt.type, @y: !felt.type]> + %x = pod.read %elt[@x] : !pod.type<[@x: !felt.type, @y: !felt.type]>, !felt.type + %y = pod.read %elt[@y] : !pod.type<[@x: !felt.type, @y: !felt.type]>, !felt.type + constrain.eq %x, %y : !felt.type + function.return + } + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 - 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-NEXT: struct.member @m_x : !array.type<2 x !felt.type> {column} +// CHECK-NEXT: struct.member @m_y : !array.type<2 x !felt.type> {column} +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@S>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@m_x] {(){{\[}}%[[VAL_3]]]} : <@S>, !array.type<2 x !felt.type> {tableOffset = #[[$M]]} +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@m_y] {(){{\[}}%[[VAL_3]]]} : <@S>, !array.type<2 x !felt.type> {tableOffset = #[[$M]]} +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_4]]{{\[}}%[[VAL_6]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_5]]{{\[}}%[[VAL_6]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_7]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/member_with_nested_pod_array.llzk b/test/Transforms/PodToScalar/member_with_nested_pod_array.llzk new file mode 100644 index 000000000..5c675fe47 --- /dev/null +++ b/test/Transforms/PodToScalar/member_with_nested_pod_array.llzk @@ -0,0 +1,63 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + + function.def @compute(%p: !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]>) + -> !struct.type<@S> { + %self = struct.new : !struct.type<@S> + struct.writem %self[@m] = %p : !struct.type<@S>, !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + %loaded = struct.readm %self[@m] : !struct.type<@S>, !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + %b = pod.read %loaded[@b] : !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]>, index + function.return %self : !struct.type<@S> + } + + function.def @constrain( + %self: !struct.type<@S>, %p: !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + ) { + %loaded = struct.readm %self[@m] : !struct.type<@S>, !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + %b = pod.read %loaded[@b] : !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]>, index + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-DAG: struct.member @m_a_x : !array.type<2 x index> +// CHECK-DAG: struct.member @m_a_y : !array.type<2 x !felt.type> +// CHECK-DAG: struct.member @m_b : index +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, %[[VAL_2:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-DAG: struct.writem %[[VAL_3]][@m_a_y] = %[[VAL_1]] : <@S>, !array.type<2 x !felt.type> +// CHECK-DAG: struct.writem %[[VAL_3]][@m_b] = %[[VAL_2]] : <@S>, index +// CHECK-DAG: struct.writem %[[VAL_3]][@m_a_x] = %[[VAL_0]] : <@S>, !array.type<2 x index> +// CHECK-DAG: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@m_a_y] : <@S>, !array.type<2 x !felt.type> +// CHECK-DAG: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@m_b] : <@S>, index +// CHECK-DAG: %[[VAL_6:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@m_a_x] : <@S>, !array.type<2 x index> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_6]]{{\[}}%[[VAL_7]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_4]]{{\[}}%[[VAL_9]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_6]]{{\[}}%[[VAL_11]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_4]]{{\[}}%[[VAL_13]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: function.return %[[VAL_3]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_15:[0-9a-zA-Z_\.]+]]: !struct.type<@S>, %[[VAL_16:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_17:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, %[[VAL_18:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-DAG: %[[VAL_19:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_15]][@m_a_y] : <@S>, !array.type<2 x !felt.type> +// CHECK-DAG: %[[VAL_20:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_15]][@m_b] : <@S>, index +// CHECK-DAG: %[[VAL_21:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_15]][@m_a_x] : <@S>, !array.type<2 x index> +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_21]]{{\[}}%[[VAL_22]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_24]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: %[[VAL_26:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_27:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_21]]{{\[}}%[[VAL_26]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_28:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_29:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_28]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/nonstatic_nested_pod_array.llzk b/test/Transforms/PodToScalar/nonstatic_nested_pod_array.llzk new file mode 100644 index 000000000..ddcf94efe --- /dev/null +++ b/test/Transforms/PodToScalar/nonstatic_nested_pod_array.llzk @@ -0,0 +1,55 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + function.def @return_direct(%n: index) -> !Outer { + %p = pod.new()[%n] : !Outer + function.return %p : !Outer + } +} +// CHECK: #[[$MAP0:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @return_direct(%[[N0:[0-9a-zA-Z_\.]+]]: index) -> !array.type<#[[$MAP0]] x index> { +// CHECK-NEXT: %[[ARR0:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#[[$MAP0]] x index> +// CHECK-NEXT: function.return %[[ARR0]] : !array.type<#[[$MAP0]] x index> +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !Outer + + function.def @compute(%n: index) -> !struct.type<@S> { + %p = pod.new()[%n] : !Outer + %self = struct.new : !struct.type<@S> + struct.writem %self[@m] = %p : !struct.type<@S>, !Outer + function.return %self : !struct.type<@S> + } + + function.def @constrain(%self: !struct.type<@S>, %n: index) { + function.return + } + } +} +// CHECK: #[[$MAP1:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-NEXT: struct.member @m_items_x : !array.type<#[[$MAP1]] x index> +// CHECK-NEXT: function.def @compute(%[[N1:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[SELF:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-NEXT: %[[ARR1:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#[[$MAP1]] x index> +// CHECK-NEXT: struct.writem %[[SELF]][@m_items_x] = %[[ARR1]] : <@S>, !array.type<#[[$MAP1]] x index> +// CHECK-NEXT: function.return %[[SELF]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[SELFARG:[0-9a-zA-Z_\.]+]]: !struct.type<@S>, %[[N2:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/return_array_from_pod_field.llzk b/test/Transforms/PodToScalar/return_array_from_pod_field.llzk new file mode 100644 index 000000000..4e1a36407 --- /dev/null +++ b/test/Transforms/PodToScalar/return_array_from_pod_field.llzk @@ -0,0 +1,20 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index, @y: !felt.type]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + function.def @get_items(%p: !Outer) -> !array.type<#outer x !Item> { + %items = pod.read %p[@items] : !Outer, !array.type<#outer x !Item> + function.return %items : !array.type<#outer x !Item> + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @get_items( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type> +// CHECK-SAME: ) -> (!array.type<#[[$M]] x index>, !array.type<#[[$M]] x !felt.type>) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<#[[$M]] x index>, !array.type<#[[$M]] x !felt.type> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/static_unwritten_nested_pod_array.llzk b/test/Transforms/PodToScalar/static_unwritten_nested_pod_array.llzk new file mode 100644 index 000000000..02a41875d --- /dev/null +++ b/test/Transforms/PodToScalar/static_unwritten_nested_pod_array.llzk @@ -0,0 +1,50 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<2 x !Item>]> +module attributes {llzk.lang} { + function.def @return_unwritten_static() -> !Outer { + %p = pod.new : !Outer + function.return %p : !Outer + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @return_unwritten_static() -> !array.type<2 x index> { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: function.return %[[ARR]] : !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<2 x !Item>]> +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !Outer + + function.def @compute() -> !struct.type<@S> { + %p = pod.new : !Outer + %self = struct.new : !struct.type<@S> + struct.writem %self[@m] = %p : !struct.type<@S>, !Outer + function.return %self : !struct.type<@S> + } + + function.def @constrain(%self: !struct.type<@S>) { + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-NEXT: struct.member @m_items_x : !array.type<2 x index> +// CHECK-NEXT: function.def @compute() -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[SELF:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: struct.writem %[[SELF]][@m_items_x] = %[[ARR]] : <@S>, !array.type<2 x index> +// CHECK-NEXT: function.return %[[SELF]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[SELFARG:[0-9a-zA-Z_\.]+]]: !struct.type<@S>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk b/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk new file mode 100644 index 000000000..21689a0b9 --- /dev/null +++ b/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk @@ -0,0 +1,55 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#map = affine_map<()[s0] -> (s0)> +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @sink(%arr: !array.type) -> index { + %c0 = arith.constant 0 : index + %len = array.len %arr, %c0 : !array.type + function.return %len : index + } + + function.def @main(%arr: !array.type<#map x !Pair>) -> (!array.type, index) { + %cast = poly.unifiable_cast %arr : (!array.type<#map x !Pair>) -> !array.type + %len = function.call @sink(%cast) : (!array.type) -> index + function.return %cast, %len : !array.type, index + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @sink(%[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type, %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type) -> index { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[C0]] : +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @main(%[[ARG_X:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[ARG_Y:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> (!array.type, !array.type, index) { +// CHECK-NEXT: %[[CAST_X:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[ARG_X]] : (!array.type<#[[$M]] x index>) -> !array.type +// CHECK-NEXT: %[[CAST_Y:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[ARG_Y]] : (!array.type<#[[$M]] x !felt.type>) -> !array.type +// CHECK-NEXT: %[[CALL_LEN:[0-9a-zA-Z_\.]+]] = function.call @sink(%[[CAST_X]], %[[CAST_Y]]) : (!array.type, !array.type) -> index +// CHECK-NEXT: function.return %[[CAST_X]], %[[CAST_Y]], %[[CALL_LEN]] : !array.type, !array.type, index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#map = affine_map<()[s0] -> (s0)> +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + poly.template @CastToTypeVar { + poly.param @T_return : !poly.tvar<@T_return> + function.def @main(%arr: !array.type<#map x !Pair>) -> !poly.tvar<@T_return> { + %cast = poly.unifiable_cast %arr : (!array.type<#map x !Pair>) -> !poly.tvar<@T_return> + function.return %cast : !poly.tvar<@T_return> + } + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: poly.template @CastToTypeVar { +// CHECK-NEXT: poly.param @T_return : !poly.tvar<@T_return> +// CHECK-NEXT: function.def @main(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> !poly.tvar<@T_return> { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_0]] : (!array.type<#[[$M]] x index>) -> !poly.tvar<@T_return> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_2]] : (!poly.tvar<@T_return>) -> !poly.tvar<@T_return> +// CHECK-NEXT: function.return %[[VAL_3]] : !poly.tvar<@T_return> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/virtual_pod_write_then_read.llzk b/test/Transforms/PodToScalar/virtual_pod_write_then_read.llzk new file mode 100644 index 000000000..57d8107a3 --- /dev/null +++ b/test/Transforms/PodToScalar/virtual_pod_write_then_read.llzk @@ -0,0 +1,82 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + function.def @arg_case(%p: !Single, %new: index) -> index { + pod.write %p[@x] = %new : !Single, index + %x = pod.read %p[@x] : !Single, index + function.return %x : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @arg_case(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: function.return %[[VAL_1]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + function.def @mk(%old: index) -> !Single { + %p = pod.new : !Single + pod.write %p[@x] = %old : !Single, index + function.return %p : !Single + } + + function.def @call_case(%old: index, %new: index) -> index { + %p = function.call @mk(%old) : (index) -> !Single + pod.write %p[@x] = %new : !Single, index + %x = pod.read %p[@x] : !Single, index + function.return %x : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @mk(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: function.return %[[VAL_0]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_case(%[[VAL_1:[0-9a-zA-Z_\.]+]]: index, %[[VAL_2:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = function.call @mk(%[[VAL_1]]) : (index) -> index +// CHECK-NEXT: function.return %[[VAL_2]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + struct.def @Box { + struct.member @p : !Single + struct.member @out : index + + function.def @compute(%old: index, %new: index) -> !struct.type<@Box> { + %self = struct.new : !struct.type<@Box> + %p = pod.new : !Single + pod.write %p[@x] = %old : !Single, index + struct.writem %self[@p] = %p : !struct.type<@Box>, !Single + %loaded = struct.readm %self[@p] : !struct.type<@Box>, !Single + pod.write %loaded[@x] = %new : !Single, index + %x = pod.read %loaded[@x] : !Single, index + struct.writem %self[@out] = %x : !struct.type<@Box>, index + function.return %self : !struct.type<@Box> + } + + function.def @constrain(%self: !struct.type<@Box>, %old: index, %new: index) { + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Box { +// CHECK-NEXT: struct.member @p_x : index +// CHECK-NEXT: struct.member @out : index +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@Box> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@Box> +// CHECK-NEXT: struct.writem %[[VAL_2]][@p_x] = %[[VAL_0]] : <@Box>, index +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@p_x] : <@Box>, index +// CHECK-NEXT: struct.writem %[[VAL_2]][@out] = %[[VAL_1]] : <@Box>, index +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@Box> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !struct.type<@Box>, %[[VAL_5:[0-9a-zA-Z_\.]+]]: index, %[[VAL_6:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/whole_pod_use_after_virtual_write.llzk b/test/Transforms/PodToScalar/whole_pod_use_after_virtual_write.llzk new file mode 100644 index 000000000..100584416 --- /dev/null +++ b/test/Transforms/PodToScalar/whole_pod_use_after_virtual_write.llzk @@ -0,0 +1,25 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + function.def @whole_use_after_computed_write( + %lhs: !Single, %rhs: !Single, %a: index, %b: index + ) attributes {function.allow_constraint} { + %sum = arith.addi %a, %b : index + pod.write %lhs[@x] = %sum : !Single, index + constrain.eq %lhs, %rhs : !Single + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @whole_use_after_computed_write( +// CHECK-SAME: %[[UNUSED:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[RHS:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[A:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[B:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[SUM:[0-9a-zA-Z_\.]+]] = arith.addi %[[A]], %[[B]] : index +// CHECK-NEXT: constrain.eq %[[SUM]], %[[RHS]] : index, index +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: }