From d659eaffd3d6517d3c3de4cf005ae124f8a83ee4 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 18 Jun 2026 16:40:05 -0500 Subject: [PATCH 01/36] changelog --- changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml diff --git a/changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml b/changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml new file mode 100644 index 000000000..d62356e20 --- /dev/null +++ b/changelogs/unreleased/th__make_pod_to_scalar_split_arrays.yaml @@ -0,0 +1,2 @@ +changed: + - Update `llzk-pod-to-scalar` to split pods within arrays by creating multiple arrays. From d256d3296507a089122351491559887285408706 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 18 Jun 2026 16:38:24 -0500 Subject: [PATCH 02/36] add test case --- .../PodToScalar/circom_decomp_prod.llzk | 557 ++++++++++++++++++ 1 file changed, 557 insertions(+) create mode 100644 test/Transforms/PodToScalar/circom_decomp_prod.llzk diff --git a/test/Transforms/PodToScalar/circom_decomp_prod.llzk b/test/Transforms/PodToScalar/circom_decomp_prod.llzk new file mode 100644 index 000000000..9c2d0a87c --- /dev/null +++ b/test/Transforms/PodToScalar/circom_decomp_prod.llzk @@ -0,0 +1,557 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s 2>&1 | FileCheck --enable-var-scope %s + +!F = !felt.type<"bn128"> +module attributes {llzk.lang = "circom", llzk.main = !struct.type<@DecomposeProduct_1>} { + struct.def @Num2Bits_0 { + struct.member @out : !array.type<8 x !F> {llzk.pub} + function.def @compute(%arg0: !F) -> !struct.type<@Num2Bits_0> attributes {function.allow_non_native_field_ops} { + %self = struct.new : <@Num2Bits_0> + %nondet = llzk.nondet : !array.type<8 x !F> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %felt_const_1 = felt.const 1 : !F + %felt_const_0_0 = felt.const 0 : !F + %0:3 = scf.while (%arg1 = %felt_const_1, %arg2 = %felt_const_0, %arg3 = %felt_const_0_0) : (!F, !F, !F) -> (!F, !F, !F) { + %felt_const_8_1 = felt.const 8 : !F + %1 = bool.cmp lt(%arg3, %felt_const_8_1) : !F, !F + scf.condition(%1) %arg1, %arg2, %arg3 : !F, !F, !F + } do { + ^bb0(%arg1: !F, %arg2: !F, %arg3: !F): + %1 = felt.shr %arg0, %arg3 : !F, !F + %felt_const_1_1 = felt.const 1 : !F + %2 = felt.bit_and %1, %felt_const_1_1 : !F, !F + %3 = cast.toindex %arg3 : !F + array.write %nondet[%3] = %2 : <8 x !F>, !F + %4 = cast.toindex %arg3 : !F + %5 = array.read %nondet[%4] : <8 x !F>, !F + %6 = felt.mul %5, %felt_const_1 : !F, !F + %7 = felt.add %arg2, %6 : !F, !F + %8 = felt.add %arg1, %arg1 : !F, !F + %felt_const_1_2 = felt.const 1 : !F + %9 = felt.add %arg3, %felt_const_1_2 : !F, !F + scf.yield %8, %7, %9 : !F, !F, !F + } + struct.writem %self[@out] = %nondet : <@Num2Bits_0>, !array.type<8 x !F> + function.return %self : !struct.type<@Num2Bits_0> + } + function.def @constrain(%arg0: !struct.type<@Num2Bits_0>, %arg1: !F) attributes {function.allow_non_native_field_ops} { + %0 = struct.readm %arg0[@out] : <@Num2Bits_0>, !array.type<8 x !F> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %felt_const_1 = felt.const 1 : !F + %felt_const_0_0 = felt.const 0 : !F + %1:3 = scf.while (%arg2 = %felt_const_0, %arg3 = %felt_const_1, %arg4 = %felt_const_0_0) : (!F, !F, !F) -> (!F, !F, !F) { + %felt_const_8_1 = felt.const 8 : !F + %2 = bool.cmp lt(%arg4, %felt_const_8_1) : !F, !F + scf.condition(%2) %arg2, %arg3, %arg4 : !F, !F, !F + } do { + ^bb0(%arg2: !F, %arg3: !F, %arg4: !F): + %2 = cast.toindex %arg4 : !F + %3 = array.read %0[%2] : <8 x !F>, !F + %4 = cast.toindex %arg4 : !F + %5 = array.read %0[%4] : <8 x !F>, !F + %felt_const_1_1 = felt.const 1 : !F + %6 = felt.sub %5, %felt_const_1_1 : !F, !F + %7 = felt.mul %3, %6 : !F, !F + %felt_const_0_2 = felt.const 0 : !F + constrain.eq %7, %felt_const_0_2 : !F, !F + %8 = cast.toindex %arg4 : !F + %9 = array.read %0[%8] : <8 x !F>, !F + %10 = felt.mul %9, %felt_const_1 : !F, !F + %11 = felt.add %arg2, %10 : !F, !F + %12 = felt.add %arg3, %arg3 : !F, !F + %felt_const_1_3 = felt.const 1 : !F + %13 = felt.add %arg4, %felt_const_1_3 : !F, !F + scf.yield %11, %12, %13 : !F, !F, !F + } + constrain.eq %1#0, %arg1 : !F, !F + function.return + } + } + struct.def @DecomposeProduct_1 { + struct.member @high : !array.type<8 x !F> {llzk.pub} + struct.member @low : !array.type<8 x !F> {llzk.pub} + struct.member @u16s : !array.type<8 x !F> + struct.member @bits_low : !array.type<8 x !struct.type<@Num2Bits_0>> + struct.member @bits_low$inputs : !array.type<8 x !pod.type<[@in: !F]>> + struct.member @bits_high : !array.type<8 x !struct.type<@Num2Bits_0>> + struct.member @bits_high$inputs : !array.type<8 x !pod.type<[@in: !F]>> + function.def @compute(%arg0: !array.type<8 x !F>, %arg1: !array.type<8 x !F>) -> !struct.type<@DecomposeProduct_1> attributes {function.allow_non_native_field_ops} { + %self = struct.new : <@DecomposeProduct_1> + %nondet = llzk.nondet : !array.type<8 x !F> + %nondet_0 = llzk.nondet : !array.type<8 x !F> + %nondet_1 = llzk.nondet : !array.type<8 x !F> + %array = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.for %arg2 = %c0 to %c8 step %c1 { + %1 = array.read %array[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %c1_16 = arith.constant 1 : index + pod.write %1[@count] = %c1_16 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + array.write %array[%arg2] = %1 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } + %array_2 = array.new : <8 x !pod.type<[@in: !F]>> + %array_3 = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> + %c8_4 = arith.constant 8 : index + %c0_5 = arith.constant 0 : index + %c1_6 = arith.constant 1 : index + scf.for %arg2 = %c0_5 to %c8_4 step %c1_6 { + %1 = array.read %array_3[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %c1_16 = arith.constant 1 : index + pod.write %1[@count] = %c1_16 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + array.write %array_3[%arg2] = %1 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } + %array_7 = array.new : <8 x !pod.type<[@in: !F]>> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %0:3 = scf.while (%arg2 = %array_7, %arg3 = %array_2, %arg4 = %felt_const_0) : (!array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F) -> (!array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F) { + %felt_const_8_16 = felt.const 8 : !F + %1 = bool.cmp lt(%arg4, %felt_const_8_16) : !F, !F + scf.condition(%1) %arg2, %arg3, %arg4 : !array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F + } do { + ^bb0(%arg2: !array.type<8 x !pod.type<[@in: !F]>>, %arg3: !array.type<8 x !pod.type<[@in: !F]>>, %arg4: !F): + %1 = cast.toindex %arg4 : !F + %2 = array.read %arg0[%1] : <8 x !F>, !F + %3 = cast.toindex %arg4 : !F + %4 = array.read %arg1[%3] : <8 x !F>, !F + %5 = felt.mul %2, %4 : !F, !F + %6 = cast.toindex %arg4 : !F + array.write %nondet[%6] = %5 : <8 x !F>, !F + %7 = cast.toindex %arg4 : !F + %8 = array.read %nondet[%7] : <8 x !F>, !F + %felt_const_256 = felt.const 256 : !F + %9 = felt.umod %8, %felt_const_256 : !F, !F + %10 = cast.toindex %arg4 : !F + array.write %nondet_0[%10] = %9 : <8 x !F>, !F + %11 = cast.toindex %arg4 : !F + %12 = array.read %nondet_0[%11] : <8 x !F>, !F + %13 = cast.toindex %arg4 : !F + %14 = array.read %arg3[%13] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + pod.write %14[@in] = %12 : <[@in: !F]>, !F + %15 = cast.toindex %arg4 : !F + array.write %arg3[%15] = %14 : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %16 = cast.toindex %arg4 : !F + %17 = array.read %array[%16] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %18 = pod.read %17[@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c1_16 = arith.constant 1 : index + %19 = arith.subi %18, %c1_16 : index + pod.write %17[@count] = %19 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c0_17 = arith.constant 0 : index + %20 = arith.cmpi eq, %19, %c0_17 : index + scf.if %20 { + %37 = pod.read %14[@in] : <[@in: !F]>, !F + %38 = function.call @Num2Bits_0::@compute(%37) : (!F) -> !struct.type<@Num2Bits_0> + pod.write %17[@comp] = %38 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + %39 = cast.toindex %arg4 : !F + array.write %array[%39] = %17 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } else { + } + %21 = cast.toindex %arg4 : !F + %22 = array.read %nondet[%21] : <8 x !F>, !F + %felt_const_256_18 = felt.const 256 : !F + %23 = felt.uintdiv %22, %felt_const_256_18 : !F, !F + %felt_const_256_19 = felt.const 256 : !F + %24 = felt.umod %23, %felt_const_256_19 : !F, !F + %25 = cast.toindex %arg4 : !F + array.write %nondet_1[%25] = %24 : <8 x !F>, !F + %26 = cast.toindex %arg4 : !F + %27 = array.read %nondet_1[%26] : <8 x !F>, !F + %28 = cast.toindex %arg4 : !F + %29 = array.read %arg2[%28] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + pod.write %29[@in] = %27 : <[@in: !F]>, !F + %30 = cast.toindex %arg4 : !F + array.write %arg2[%30] = %29 : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %31 = cast.toindex %arg4 : !F + %32 = array.read %array_3[%31] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %33 = pod.read %32[@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c1_20 = arith.constant 1 : index + %34 = arith.subi %33, %c1_20 : index + pod.write %32[@count] = %34 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index + %c0_21 = arith.constant 0 : index + %35 = arith.cmpi eq, %34, %c0_21 : index + scf.if %35 { + %37 = pod.read %29[@in] : <[@in: !F]>, !F + %38 = function.call @Num2Bits_0::@compute(%37) : (!F) -> !struct.type<@Num2Bits_0> + pod.write %32[@comp] = %38 : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + %39 = cast.toindex %arg4 : !F + array.write %array_3[%39] = %32 : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + } else { + } + %felt_const_1 = felt.const 1 : !F + %36 = felt.add %arg4, %felt_const_1 : !F, !F + scf.yield %arg2, %arg3, %36 : !array.type<8 x !pod.type<[@in: !F]>>, !array.type<8 x !pod.type<[@in: !F]>>, !F + } + struct.writem %self[@bits_high$inputs] = %0#0 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %array_8 = array.new : <8 x !struct.type<@Num2Bits_0>> + %c8_9 = arith.constant 8 : index + %c0_10 = arith.constant 0 : index + %c1_11 = arith.constant 1 : index + scf.for %arg2 = %c0_10 to %c8_9 step %c1_11 { + %1 = array.read %array_3[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %2 = pod.read %1[@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + array.write %array_8[%arg2] = %2 : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + } + struct.writem %self[@bits_high] = %array_8 : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + struct.writem %self[@bits_low$inputs] = %0#1 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %array_12 = array.new : <8 x !struct.type<@Num2Bits_0>> + %c8_13 = arith.constant 8 : index + %c0_14 = arith.constant 0 : index + %c1_15 = arith.constant 1 : index + scf.for %arg2 = %c0_14 to %c8_13 step %c1_15 { + %1 = array.read %array[%arg2] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> + %2 = pod.read %1[@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> + array.write %array_12[%arg2] = %2 : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + } + struct.writem %self[@bits_low] = %array_12 : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + struct.writem %self[@u16s] = %nondet : <@DecomposeProduct_1>, !array.type<8 x !F> + struct.writem %self[@low] = %nondet_0 : <@DecomposeProduct_1>, !array.type<8 x !F> + struct.writem %self[@high] = %nondet_1 : <@DecomposeProduct_1>, !array.type<8 x !F> + function.return %self : !struct.type<@DecomposeProduct_1> + } + function.def @constrain(%arg0: !struct.type<@DecomposeProduct_1>, %arg1: !array.type<8 x !F>, %arg2: !array.type<8 x !F>) attributes {function.allow_non_native_field_ops} { + %0 = struct.readm %arg0[@high] : <@DecomposeProduct_1>, !array.type<8 x !F> + %1 = struct.readm %arg0[@low] : <@DecomposeProduct_1>, !array.type<8 x !F> + %2 = struct.readm %arg0[@u16s] : <@DecomposeProduct_1>, !array.type<8 x !F> + %3 = struct.readm %arg0[@bits_low] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + %4 = struct.readm %arg0[@bits_low$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %5 = struct.readm %arg0[@bits_high] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> + %6 = struct.readm %arg0[@bits_high$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !F]>> + %felt_const_8 = felt.const 8 : !F + %felt_const_0 = felt.const 0 : !F + %7 = scf.while (%arg3 = %felt_const_0) : (!F) -> !F { + %felt_const_8_3 = felt.const 8 : !F + %8 = bool.cmp lt(%arg3, %felt_const_8_3) : !F, !F + scf.condition(%8) %arg3 : !F + } do { + ^bb0(%arg3: !F): + %8 = cast.toindex %arg3 : !F + %9 = array.read %arg1[%8] : <8 x !F>, !F + %10 = cast.toindex %arg3 : !F + %11 = array.read %arg2[%10] : <8 x !F>, !F + %12 = felt.mul %9, %11 : !F, !F + %13 = cast.toindex %arg3 : !F + %14 = array.read %2[%13] : <8 x !F>, !F + constrain.eq %14, %12 : !F, !F + %15 = cast.toindex %arg3 : !F + %16 = array.read %1[%15] : <8 x !F>, !F + %17 = cast.toindex %arg3 : !F + %18 = array.read %4[%17] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %19 = pod.read %18[@in] : <[@in: !F]>, !F + constrain.eq %19, %16 : !F, !F + %20 = cast.toindex %arg3 : !F + %21 = array.read %0[%20] : <8 x !F>, !F + %22 = cast.toindex %arg3 : !F + %23 = array.read %6[%22] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %24 = pod.read %23[@in] : <[@in: !F]>, !F + constrain.eq %24, %21 : !F, !F + %25 = cast.toindex %arg3 : !F + %26 = array.read %2[%25] : <8 x !F>, !F + %27 = cast.toindex %arg3 : !F + %28 = array.read %1[%27] : <8 x !F>, !F + %felt_const_256 = felt.const 256 : !F + %29 = cast.toindex %arg3 : !F + %30 = array.read %0[%29] : <8 x !F>, !F + %31 = felt.mul %felt_const_256, %30 : !F, !F + %32 = felt.add %28, %31 : !F, !F + constrain.eq %26, %32 : !F, !F + %felt_const_1 = felt.const 1 : !F + %33 = felt.add %arg3, %felt_const_1 : !F, !F + scf.yield %33 : !F + } + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.for %arg3 = %c0 to %c8 step %c1 { + %8 = array.read %5[%arg3] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + %9 = array.read %6[%arg3] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %10 = pod.read %9[@in] : <[@in: !F]>, !F + function.call @Num2Bits_0::@constrain(%8, %10) : (!struct.type<@Num2Bits_0>, !F) -> () + } + %c8_0 = arith.constant 8 : index + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + scf.for %arg3 = %c0_1 to %c8_0 step %c1_2 { + %8 = array.read %3[%arg3] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> + %9 = array.read %4[%arg3] : <8 x !pod.type<[@in: !F]>>, !pod.type<[@in: !F]> + %10 = pod.read %9[@in] : <[@in: !F]>, !F + function.call @Num2Bits_0::@constrain(%8, %10) : (!struct.type<@Num2Bits_0>, !F) -> () + } + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang = "circom", llzk.main = !struct.type<@DecomposeProduct_1>} { +// CHECK-NEXT: struct.def @Num2Bits_0 { +// CHECK-NEXT: struct.member @out : !array.type<8 x !felt.type<"bn128">> {llzk.pub} +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) -> !struct.type<@Num2Bits_0> attributes {function.allow_non_native_field_ops, function.allow_witness} { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Num2Bits_0> +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_7:[0-9a-zA-Z_\.]+]] = %[[VAL_4]], %[[VAL_8:[0-9a-zA-Z_\.]+]] = %[[VAL_3]], %[[VAL_9:[0-9a-zA-Z_\.]+]] = %[[VAL_5]]) : (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) -> (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_9]], %[[VAL_10]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_11]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_12:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_13:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_14:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = felt.shr %[[VAL_0]], %[[VAL_14]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = felt.bit_and %[[VAL_15]], %[[VAL_16]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_14]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_2]]{{\[}}%[[VAL_18]]] = %[[VAL_17]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_19:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_14]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_2]]{{\[}}%[[VAL_19]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_20]], %[[VAL_4]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_13]], %[[VAL_21]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_12]], %[[VAL_12]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_14]], %[[VAL_24]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_23]], %[[VAL_22]], %[[VAL_25]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_1]][@out] = %[[VAL_2]] : <@Num2Bits_0>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_26:[0-9a-zA-Z_\.]+]]: !struct.type<@Num2Bits_0>, %[[VAL_27:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) attributes {function.allow_constraint, function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_28:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_26]][@out] : <@Num2Bits_0>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_29:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_30:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_31:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_32:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_33:[0-9a-zA-Z_\.]+]] = %[[VAL_29]], %[[VAL_34:[0-9a-zA-Z_\.]+]] = %[[VAL_30]], %[[VAL_35:[0-9a-zA-Z_\.]+]] = %[[VAL_31]]) : (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) -> (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_36:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_37:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_35]], %[[VAL_36]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_37]]) %[[VAL_33]], %[[VAL_34]], %[[VAL_35]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_38:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_39:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_40:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_41:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_40]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_42:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_28]]{{\[}}%[[VAL_41]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_43:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_40]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_44:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_28]]{{\[}}%[[VAL_43]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_45:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_46:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_44]], %[[VAL_45]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_47:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_42]], %[[VAL_46]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_48:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_47]], %[[VAL_48]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_49:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_40]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_50:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_28]]{{\[}}%[[VAL_49]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_51:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_50]], %[[VAL_30]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_52:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_38]], %[[VAL_51]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_53:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_39]], %[[VAL_39]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_54:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_55:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_40]], %[[VAL_54]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_52]], %[[VAL_53]], %[[VAL_55]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: constrain.eq %[[VAL_32]]#0, %[[VAL_27]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: struct.def @DecomposeProduct_1 { +// CHECK-NEXT: struct.member @high : !array.type<8 x !felt.type<"bn128">> {llzk.pub} +// CHECK-NEXT: struct.member @low : !array.type<8 x !felt.type<"bn128">> {llzk.pub} +// CHECK-NEXT: struct.member @u16s : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.member @bits_low : !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.member @bits_low$inputs : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: struct.member @bits_high : !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.member @bits_high$inputs : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: function.def @compute(%[[VAL_56:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_57:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) -> !struct.type<@DecomposeProduct_1> attributes {function.allow_non_native_field_ops, function.allow_witness} { +// CHECK-NEXT: %[[VAL_58:[0-9a-zA-Z_\.]+]] = struct.new : <@DecomposeProduct_1> +// CHECK-NEXT: %[[VAL_59:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_60:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_61:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_62:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> +// CHECK-NEXT: %[[VAL_63:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_64:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_65:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_66:[0-9a-zA-Z_\.]+]] = %[[VAL_64]] to %[[VAL_63]] step %[[VAL_65]] { +// CHECK-NEXT: %[[VAL_67:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_62]]{{\[}}%[[VAL_66]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: %[[VAL_68:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: pod.write %[[VAL_67]][@count] = %[[VAL_68]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index +// CHECK-NEXT: array.write %[[VAL_62]]{{\[}}%[[VAL_66]]] = %[[VAL_67]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_69:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: %[[VAL_70:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> +// CHECK-NEXT: %[[VAL_71:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_72:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_73:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_74:[0-9a-zA-Z_\.]+]] = %[[VAL_72]] to %[[VAL_71]] step %[[VAL_73]] { +// CHECK-NEXT: %[[VAL_75:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_70]]{{\[}}%[[VAL_74]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: %[[VAL_76:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: pod.write %[[VAL_75]][@count] = %[[VAL_76]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index +// CHECK-NEXT: array.write %[[VAL_70]]{{\[}}%[[VAL_74]]] = %[[VAL_75]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_77:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: %[[VAL_78:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_79:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_80:[0-9a-zA-Z_\.]+]] = %[[VAL_77]], %[[VAL_81:[0-9a-zA-Z_\.]+]] = %[[VAL_69]], %[[VAL_82:[0-9a-zA-Z_\.]+]] = %[[VAL_78]]) : (!array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128">) -> (!array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_83:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_84:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_82]], %[[VAL_83]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_85:[0-9a-zA-Z_\.]+]]: !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, %[[VAL_86:[0-9a-zA-Z_\.]+]]: !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, %[[VAL_87:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_88:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_89:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_56]]{{\[}}%[[VAL_88]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_90:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_91:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_57]]{{\[}}%[[VAL_90]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_92:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_89]], %[[VAL_91]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_93:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_59]]{{\[}}%[[VAL_93]]] = %[[VAL_92]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_94:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_95:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_59]]{{\[}}%[[VAL_94]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_96:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_97:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_95]], %[[VAL_96]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_98:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_60]]{{\[}}%[[VAL_98]]] = %[[VAL_97]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_99:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_100:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_99]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_101:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_102:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_86]]{{\[}}%[[VAL_101]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: pod.write %[[VAL_102]][@in] = %[[VAL_100]] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_103:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_86]]{{\[}}%[[VAL_103]]] = %[[VAL_102]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: %[[VAL_104:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_105:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_62]]{{\[}}%[[VAL_104]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: %[[VAL_106:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_105]][@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index +// CHECK-NEXT: %[[VAL_107:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_108:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_106]], %[[VAL_107]] : index +// CHECK-NEXT: pod.write %[[VAL_105]][@count] = %[[VAL_108]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index +// CHECK-NEXT: %[[VAL_109:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_110:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_108]], %[[VAL_109]] : index +// CHECK-NEXT: scf.if %[[VAL_110]] { +// CHECK-NEXT: %[[VAL_111:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_100]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> +// CHECK-NEXT: pod.write %[[VAL_105]][@comp] = %[[VAL_111]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_112:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_62]]{{\[}}%[[VAL_112]]] = %[[VAL_105]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: } else { +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_113:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_114:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_59]]{{\[}}%[[VAL_113]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_115:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_116:[0-9a-zA-Z_\.]+]] = felt.uintdiv %[[VAL_114]], %[[VAL_115]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_117:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_118:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_116]], %[[VAL_117]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_119:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_61]]{{\[}}%[[VAL_119]]] = %[[VAL_118]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_120:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_121:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_120]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_122:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_123:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_85]]{{\[}}%[[VAL_122]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: pod.write %[[VAL_123]][@in] = %[[VAL_121]] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_124:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_85]]{{\[}}%[[VAL_124]]] = %[[VAL_123]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: %[[VAL_125:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_126:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_70]]{{\[}}%[[VAL_125]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: %[[VAL_127:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_126]][@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index +// CHECK-NEXT: %[[VAL_128:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_129:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_127]], %[[VAL_128]] : index +// CHECK-NEXT: pod.write %[[VAL_126]][@count] = %[[VAL_129]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index +// CHECK-NEXT: %[[VAL_130:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_131:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_129]], %[[VAL_130]] : index +// CHECK-NEXT: scf.if %[[VAL_131]] { +// CHECK-NEXT: %[[VAL_132:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_121]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> +// CHECK-NEXT: pod.write %[[VAL_126]][@comp] = %[[VAL_132]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_133:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_70]]{{\[}}%[[VAL_133]]] = %[[VAL_126]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: } else { +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_134:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_135:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_87]], %[[VAL_134]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_85]], %[[VAL_86]], %[[VAL_135]] : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_high$inputs] = %[[VAL_79]]#0 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: %[[VAL_136:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_137:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_138:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_139:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_140:[0-9a-zA-Z_\.]+]] = %[[VAL_138]] to %[[VAL_137]] step %[[VAL_139]] { +// CHECK-NEXT: %[[VAL_141:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_70]]{{\[}}%[[VAL_140]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: %[[VAL_142:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_141]][@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: array.write %[[VAL_136]]{{\[}}%[[VAL_140]]] = %[[VAL_142]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_high] = %[[VAL_136]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_low$inputs] = %[[VAL_79]]#1 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: %[[VAL_143:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_144:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_145:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_146:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_147:[0-9a-zA-Z_\.]+]] = %[[VAL_145]] to %[[VAL_144]] step %[[VAL_146]] { +// CHECK-NEXT: %[[VAL_148:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_62]]{{\[}}%[[VAL_147]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: %[[VAL_149:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_148]][@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: array.write %[[VAL_143]]{{\[}}%[[VAL_147]]] = %[[VAL_149]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: } +// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_low] = %[[VAL_143]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.writem %[[VAL_58]][@u16s] = %[[VAL_59]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.writem %[[VAL_58]][@low] = %[[VAL_60]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.writem %[[VAL_58]][@high] = %[[VAL_61]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: function.return %[[VAL_58]] : !struct.type<@DecomposeProduct_1> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_150:[0-9a-zA-Z_\.]+]]: !struct.type<@DecomposeProduct_1>, %[[VAL_151:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_152:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) attributes {function.allow_constraint, function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_153:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@high] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_154:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@low] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_155:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@u16s] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_156:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_low] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_157:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_low$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: %[[VAL_158:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_high] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_159:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_high$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: %[[VAL_160:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_161:[0-9a-zA-Z_\.]+]] = scf.while (%[[VAL_162:[0-9a-zA-Z_\.]+]] = %[[VAL_160]]) : (!felt.type<"bn128">) -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_163:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_164:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_162]], %[[VAL_163]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_164]]) %[[VAL_162]] : !felt.type<"bn128"> +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_165:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_166:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_167:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_151]]{{\[}}%[[VAL_166]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_168:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_169:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_152]]{{\[}}%[[VAL_168]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_170:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_167]], %[[VAL_169]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_171:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_172:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_155]]{{\[}}%[[VAL_171]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_172]], %[[VAL_170]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_173:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_174:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_154]]{{\[}}%[[VAL_173]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_175:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_176:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_157]]{{\[}}%[[VAL_175]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: %[[VAL_177:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_176]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_177]], %[[VAL_174]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_178:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_179:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_153]]{{\[}}%[[VAL_178]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_180:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_181:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_159]]{{\[}}%[[VAL_180]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: %[[VAL_182:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_181]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_182]], %[[VAL_179]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_183:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_184:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_155]]{{\[}}%[[VAL_183]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_185:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_186:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_154]]{{\[}}%[[VAL_185]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_187:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_188:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_189:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_153]]{{\[}}%[[VAL_188]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_190:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_187]], %[[VAL_189]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_191:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_186]], %[[VAL_190]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_184]], %[[VAL_191]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_192:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_193:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_165]], %[[VAL_192]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_193]] : !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_194:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_195:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_196:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_197:[0-9a-zA-Z_\.]+]] = %[[VAL_195]] to %[[VAL_194]] step %[[VAL_196]] { +// CHECK-NEXT: %[[VAL_198:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_158]]{{\[}}%[[VAL_197]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_199:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_159]]{{\[}}%[[VAL_197]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: %[[VAL_200:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_199]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> +// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_198]], %[[VAL_200]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[VAL_201:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_202:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_203:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_204:[0-9a-zA-Z_\.]+]] = %[[VAL_202]] to %[[VAL_201]] step %[[VAL_203]] { +// CHECK-NEXT: %[[VAL_205:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_156]]{{\[}}%[[VAL_204]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_206:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_157]]{{\[}}%[[VAL_204]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> +// CHECK-NEXT: %[[VAL_207:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_206]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> +// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_205]], %[[VAL_207]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () +// CHECK-NEXT: } +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } From f6e898ef79ca8b91d20dbb381f9e213acba46f93 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 18 Jun 2026 17:54:46 -0500 Subject: [PATCH 03/36] split pods within arrays to multiple arrays --- .../POD/Transforms/PodToScalarPass.cpp | 802 ++++++++++++++++-- .../LLZKTransformationPassPipelines.cpp | 4 +- .../PodToScalar/circom_decomp_prod.llzk | 454 +++++----- .../PodToScalar/function_calls_with_pod.llzk | 10 +- 4 files changed, 958 insertions(+), 312 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index a44b9af65..9e1cc1847 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -15,26 +15,27 @@ /// 0. Scan to find `llzk.nondet` ops that allocate uninitialized pods and replace them with /// an equivalent `pod.new` /// -/// 1. Run a dialect conversion that replaces `PodType` struct members with one scalar member per -/// record and remembers how each original member was split. +/// 1. Run a dialect conversion that replaces direct `PodType` struct members with one scalar +/// member per record, replaces arrays whose element type is a POD with one parallel array per +/// scalar leaf record, and remembers how each original member was split. /// /// 2. Run a dialect conversion that does the following: /// /// - Replace `MemberReadOp` and `MemberWriteOp` targeting the members that were split in step 1 -/// so they instead perform scalar reads and writes from the new members. The transformation is -/// local to the current op. Therefore, when replacing the `MemberReadOp` a new pod is -/// created locally and all uses of the `MemberReadOp` are replaced with the new pod Value, -/// then each scalar member read is followed by scalar write into the new pod. Similarly, -/// when replacing a `MemberWriteOp`, each element in the pod operand needs a scalar read -/// from the pod followed by a scalar write to the new member. Making only local changes -/// keeps this step simple and later steps will optimize. +/// so they instead perform reads and writes on the new scalar or parallel-array members. The +/// transformation is local to the current op. Therefore, when replacing the `MemberReadOp` a +/// new pod is created locally and all uses of the `MemberReadOp` are replaced with the new +/// pod Value, then each scalar member read is followed by scalar write into the new pod. +/// Similarly, when replacing a `MemberWriteOp`, each element in the pod operand needs a +/// scalar read from the pod followed by a scalar write to the new member. Making only local +/// changes keeps this step simple and later steps will optimize. /// /// - Remove optional initialization from `NewPodOp` and instead insert a list of `WritePodOp` /// immediately following. /// -/// - Split pods to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary -/// create/read/write ops so the changes are as local as possible (just as described for -/// `MemberReadOp` and `MemberWriteOp`) +/// - Split arrays whose element type is a POD into parallel arrays in `array.*`, +/// `FuncDefOp`, `CallOp`, and `ReturnOp`, then split remaining direct POD values to scalars +/// in `FuncDefOp`, `CallOp`, and `ReturnOp`. /// /// 3. Promote pod reads and writes out of `scf.if`, `scf.for`, and `scf.while` regions when the /// access can be modeled as an SSA value flowing through the region boundary. This puts the @@ -62,6 +63,8 @@ //===----------------------------------------------------------------------===// #include "llzk/Dialect/Array/IR/Dialect.h" +#include "llzk/Dialect/Array/IR/Ops.h" +#include "llzk/Dialect/Array/IR/Types.h" #include "llzk/Dialect/Bool/IR/Dialect.h" #include "llzk/Dialect/Cast/IR/Dialect.h" #include "llzk/Dialect/Constrain/IR/Dialect.h" @@ -86,6 +89,7 @@ #include "llzk/Util/Walk.h" #include +#include #include #include #include @@ -103,6 +107,7 @@ namespace llzk::pod { using namespace mlir; using namespace llzk; +using namespace llzk::array; using namespace llzk::pod; using namespace llzk::function; using namespace llzk::component; @@ -216,6 +221,107 @@ splitPodType(TypeCollection types, SmallVector *originalIdxToSize = null return collect; } +/// Visit each non-POD leaf record in `podTy`, providing its record-name chain and leaf type. +template +static void forEachPodLeaf(PodType podTy, SmallVectorImpl &recordChain, Fn &&callback) { + for (RecordAttr record : podTy.getRecords()) { + recordChain.push_back(record.getName()); + if (PodType nestedPodTy = dyn_cast(record.getType())) { + forEachPodLeaf(nestedPodTy, recordChain, callback); + } else { + callback(RecordChain(recordChain), record.getType()); + } + recordChain.pop_back(); + } +} + +/// If the given ArrayType has a POD element type, return it. +inline static ArrayType splittablePodArray(ArrayType at) { + return isa(at.getElementType()) ? at : nullptr; +} + +/// If the given Type is an ArrayType with a POD element type, return it. +inline static ArrayType splittablePodArray(Type t) { + if (ArrayType at = dyn_cast(t)) { + return splittablePodArray(at); + } + return nullptr; +} + +/// Return `true` iff any type in the range is an array whose element type is a POD. +inline static bool containsSplittablePodArrayType(ArrayRef types) { + return llvm::any_of(types, [](Type t) { return splittablePodArray(t); }); +} + +/// Return `true` iff any type in the range is an array whose element type is a POD. +template static bool containsSplittablePodArrayType(ValueTypeRange types) { + return llvm::any_of(types, [](Type t) { return splittablePodArray(t); }); +} + +/// If `t` is an array with POD element type, append one parallel array type for each POD leaf. +static size_t splitPodArrayTypeTo( + Type t, SmallVectorImpl &collect, SmallVector *splitIds = nullptr +) { + if (ArrayType at = splittablePodArray(t)) { + auto podTy = llvm::cast(at.getElementType()); + SmallVector recordChain; + size_t originalSize = collect.size(); + forEachPodLeaf(podTy, recordChain, [&](RecordChain id, Type leafType) { + collect.push_back(at.cloneWith(leafType)); + if (splitIds) { + splitIds->push_back(std::move(id)); + } + }); + return collect.size() - originalSize; + } + + collect.push_back(t); + return 1; +} + +/// For each Type in the given input collection, call `splitPodArrayTypeTo(Type,...)`. +template +inline void splitPodArrayTypeTo( + TypeCollection types, SmallVectorImpl &collect, SmallVector *originalIdxToSize +) { + for (Type t : types) { + size_t count = splitPodArrayTypeTo(t, collect); + if (originalIdxToSize) { + originalIdxToSize->push_back(count); + } + } +} + +/// Return a list such that each non-array POD type is kept as-is, while each array-of-POD type is +/// replaced by one parallel array type per non-POD leaf record in the element POD. +template +inline SmallVector +splitPodArrayType(TypeCollection types, SmallVector *originalIdxToSize = nullptr) { + SmallVector collect; + splitPodArrayTypeTo(types, collect, originalIdxToSize); + return collect; +} + +/// Return the suffixes to append to a function arg/result name when splitting an array of PODs. +static SmallVector getSplitPodArrayRecordNameSuffixes(Type type) { + SmallVector suffixes; + if (ArrayType at = splittablePodArray(type)) { + SmallVector splitIds; + SmallVector ignoredTypes; + splitPodArrayTypeTo(at, ignoredTypes, &splitIds); + suffixes.reserve(splitIds.size()); + for (const RecordChain &id : splitIds) { + std::string suffix; + llvm::raw_string_ostream os(suffix); + for (StringAttr recordName : id.nameList) { + os << '.' << recordName.getValue(); + } + suffixes.push_back(std::move(suffix)); + } + } + return suffixes; +} + /// Create a `pod.read` for one record of `podRef`. inline static ReadPodOp genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &rewriter) { @@ -230,6 +336,54 @@ genWrite(Location loc, Value podRef, StringAttr recordName, Value value, OpBuild return rewriter.create(loc, podRef, recordName, value); } +/// Return the single converted value from a 1:N adaptor range. +inline static Value getSingleConvertedValue(ValueRange values) { + assert(values.size() == 1 && "expected a 1:1 converted value range"); + return values.front(); +} + +/// Flatten a range of converted value ranges into a single list of values. +template +static SmallVector flattenConvertedValues(RangeOfRanges ranges) { + SmallVector values; + for (ValueRange range : ranges) { + llvm::append_range(values, range); + } + return values; +} + +/// Read a nested POD leaf by following each record name in `recordChain`. +static Value +genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &rewriter) { + Value value = podRef; + for (StringAttr attr : recordChain.nameList) { + value = genRead(loc, value, attr, rewriter); + } + return value; +} + +/// Reconstruct a POD record from the leaf values collected while splitting nested accesses. +static Value rebuildFlattenedPodRecord( + Location loc, Type recordType, SmallVectorImpl &recordChain, + const DenseMap &leafValues, ConversionPatternRewriter &rewriter +) { + if (PodType nestedPodTy = dyn_cast(recordType)) { + NewPodOp nestedPod = rewriter.create(loc, nestedPodTy); + for (RecordAttr record : nestedPodTy.getRecords()) { + recordChain.push_back(record.getName()); + Value recordValue = + rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, rewriter); + genWrite(loc, nestedPod, record.getName(), recordValue, rewriter); + recordChain.pop_back(); + } + return nestedPod; + } + + auto it = leafValues.find(RecordChain(recordChain)); + assert(it != leafValues.end() && "missing flattened POD leaf value"); + return it->second; +} + /// Return the suffixes to append to a function arg/result name when splitting the given type. static SmallVector getSplitRecordNameSuffixes(Type type) { SmallVector suffixes; @@ -341,28 +495,16 @@ static void flattenPodMemberIntoLeaves( LocalMemberReplacementMap &localRepMapRef, SymbolTable &structSymbolTable, ConversionPatternRewriter &rewriter ) { - for (RecordAttr record : podTy.getRecords()) { - recordChain.push_back(record.getName()); - if (PodType nestedPodTy = dyn_cast(record.getType())) { - flattenPodMemberIntoLeaves( - originalMember, nestedPodTy, recordChain, localRepMapRef, structSymbolTable, rewriter - ); - recordChain.pop_back(); - continue; - } - + forEachPodLeaf(podTy, recordChain, [&](RecordChain id, Type ty) { StringAttr name = getFlattenedMemberName( - originalMember.getContext(), originalMember.getSymNameAttr(), recordChain + originalMember.getContext(), originalMember.getSymNameAttr(), id.nameList ); - Type ty = record.getType(); MemberDefOp newMember = rewriter.create( originalMember.getLoc(), name, ty, originalMember.getSignal(), originalMember.getColumn() ); newMember.setPublicAttr(originalMember.hasPublicAttr()); - localRepMapRef[RecordChain(recordChain)] = - std::make_pair(structSymbolTable.insert(newMember), ty); - recordChain.pop_back(); - } + localRepMapRef[id] = std::make_pair(structSymbolTable.insert(newMember), ty); + }); } /// Split a pod-typed struct member definition into one scalar member definition per POD record. @@ -398,20 +540,527 @@ class SplitPodInMemberDefOp : public OpConversionPattern { } }; -/// Replace `PodType` struct members with scalar members. +/// Split an array-of-POD struct member definition into one parallel array member per POD leaf. +class SplitPodArrayInMemberDefOp : public OpConversionPattern { + SymbolTableCollection &tables; + MemberReplacementMap &repMapRef; + +public: + SplitPodArrayInMemberDefOp( + MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap + ) + : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap) {} + + inline static bool legal(MemberDefOp op) { return !splittablePodArray(op.getType()); } + + LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); } + + void + rewrite(MemberDefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + StructDefOp inStruct = op->getParentOfType(); + assert(inStruct); + LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()]; + + ArrayType arrTy = llvm::cast(adaptor.getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct); + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + StringAttr name = getFlattenedMemberName(op.getContext(), op.getSymNameAttr(), id.nameList); + MemberDefOp newMember = rewriter.create( + op.getLoc(), name, splitType, op.getSignal(), op.getColumn() + ); + newMember.setPublicAttr(op.hasPublicAttr()); + localRepMapRef[id] = std::make_pair(structSymbolTable.insert(newMember), splitType); + } + rewriter.eraseOp(op); + } +}; + +/// Replace direct `PodType` struct members with scalar members and arrays-of-POD with parallel +/// array members named after the corresponding POD leaf. static LogicalResult step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) { MLIRContext *ctx = modOp.getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx, symTables, memberRepMap); + patterns.add(ctx, symTables, memberRepMap); ConversionTarget target(*ctx); baseTargetSetup(target); - target.addDynamicallyLegalOp(SplitPodInMemberDefOp::legal); + target.addDynamicallyLegalOp([](MemberDefOp op) { + return SplitPodInMemberDefOp::legal(op) && SplitPodArrayInMemberDefOp::legal(op); + }); + + LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split pod-type and array-of-pod members\n";); + return applyFullConversion(modOp, target, std::move(patterns)); +} + +/// Type converter that replaces each array-of-POD type with one parallel array type per POD leaf. +class PodArrayTypeConverter : public TypeConverter { +public: + PodArrayTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion( + [](ArrayType arrTy, SmallVectorImpl &results) -> std::optional { + if (!splittablePodArray(arrTy)) { + return std::nullopt; + } + splitPodArrayTypeTo(arrTy, results); + return success(); + } + ); + } +}; + +/// Split `llzk.nondet` of array-of-POD type into one `llzk.nondet` per parallel leaf array. +class SplitPodArrayNonDetOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(NonDetOp op) { return !splittablePodArray(op.getType()); } + + LogicalResult match(NonDetOp op) const override { return failure(legal(op)); } + + void rewrite(NonDetOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector splitTypes; + splitPodArrayTypeTo(op.getType(), splitTypes); + SmallVector replacements; + replacements.reserve(splitTypes.size()); + for (Type splitType : splitTypes) { + replacements.push_back(rewriter.create(op.getLoc(), splitType)); + } + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + } +}; + +/// Split `array.new` of array-of-POD type into one `array.new` per parallel leaf array. +class SplitPodArrayCreateArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(CreateArrayOp op) { return !splittablePodArray(op.getType()); } + + LogicalResult matchAndRewrite( + CreateArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + ArrayType arrTy = llvm::cast(op.getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector replacements; + replacements.reserve(splitTypes.size()); + DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr(); + if (isNullOrEmpty(numDimsPerMap)) { + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + SmallVector splitElements; + splitElements.reserve(adaptor.getElements().size()); + for (ValueRange elementRange : adaptor.getElements()) { + Value element = getSingleConvertedValue(elementRange); + splitElements.push_back(genReadAlongPath(op.getLoc(), element, id, rewriter)); + } + replacements.push_back(rewriter.create( + op.getLoc(), llvm::cast(splitType), splitElements + )); + } + } else { + SmallVector> mapOperandStorage; + SmallVector mapOperands; + mapOperandStorage.reserve(adaptor.getMapOperands().size()); + mapOperands.reserve(adaptor.getMapOperands().size()); + for (ArrayRef mapOperandGroup : adaptor.getMapOperands()) { + mapOperandStorage.push_back(flattenConvertedValues(mapOperandGroup)); + } + for (const SmallVector &values : mapOperandStorage) { + mapOperands.push_back(values); + } + for (Type splitType : splitTypes) { + replacements.push_back(rewriter.create( + op.getLoc(), llvm::cast(splitType), mapOperands, numDimsPerMap + )); + } + } + + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } +}; + +/// Split `array.read` from an array-of-POD into scalar leaf reads plus local POD reconstruction. +class SplitPodArrayReadArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ReadArrayOp op) { return !splittablePodArray(op.getArrRefType()); } + + LogicalResult matchAndRewrite( + ReadArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + ArrayType arrTy = op.getArrRefType(); + PodType podTy = llvm::cast(arrTy.getElementType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + NewPodOp pod = rewriter.create(op.getLoc(), podTy); + DenseMap leafValues; + for (auto [id, splitArrRange, splitType] : + llvm::zip_equal(splitIds, adaptor.getArrRef(), splitTypes)) { + auto splitArrTy = llvm::cast(splitType); + Value scalarRead = rewriter.create( + op.getLoc(), splitArrTy.getElementType(), getSingleConvertedValue(splitArrRange), indices + ); + leafValues[id] = scalarRead; + } + + SmallVector recordChain; + for (RecordAttr record : podTy.getRecords()) { + recordChain.push_back(record.getName()); + Value recordValue = rebuildFlattenedPodRecord( + op.getLoc(), record.getType(), recordChain, leafValues, rewriter + ); + genWrite(op.getLoc(), pod, record.getName(), recordValue, rewriter); + recordChain.pop_back(); + } + rewriter.replaceOp(op, pod); + return success(); + } +}; + +/// Split `array.write` to an array-of-POD into one write per parallel leaf array. +class SplitPodArrayWriteArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(WriteArrayOp op) { return !splittablePodArray(op.getArrRefType()); } + + LogicalResult matchAndRewrite( + WriteArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + ArrayType arrTy = op.getArrRefType(); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + Value podValue = getSingleConvertedValue(adaptor.getRvalue()); + for (auto [id, splitArrRange, splitType] : + llvm::zip_equal(splitIds, adaptor.getArrRef(), splitTypes)) { + Value leafValue = genReadAlongPath(op.getLoc(), podValue, id, rewriter); + rewriter.create( + op.getLoc(), getSingleConvertedValue(splitArrRange), indices, leafValue + ); + } + rewriter.eraseOp(op); + return success(); + } +}; + +/// Rewrite array-of-POD function signatures to use one parallel array per POD leaf. +class SplitPodArrayInFuncDefOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(FuncDefOp op) { + return !containsSplittablePodArrayType(op.getArgumentTypes()) && + !containsSplittablePodArrayType(op.getResultTypes()); + } + + LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); } + + LogicalResult + matchAndRewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { + auto *typeConverter = getTypeConverter(); + assert(typeConverter && "expected pod-array type converter"); + + FunctionType oldTy = op.getFunctionType(); + TypeConverter::SignatureConversion inputConversion(oldTy.getNumInputs()); + if (failed(typeConverter->convertSignatureArgs(oldTy.getInputs(), inputConversion))) { + return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod inputs"); + } + + SmallVector newResults; + if (failed(typeConverter->convertTypes(oldTy.getResults(), newResults))) { + return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod results"); + } + + if (!op.getBody().empty() && + failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter, &inputConversion))) { + return rewriter.notifyMatchFailure(op, "failed to convert function body block arguments"); + } + + SmallVector originalInputIdxToSize, originalResultIdxToSize; + SmallVector newInputs = splitPodArrayType(oldTy.getInputs(), &originalInputIdxToSize); + SplitFunctionNameInfo inputNameInfo = + collectSplitFunctionNameInfo(op.getArgumentTypes(), [&](unsigned i) { + return op.getArgNameAttr(i); + }, getSplitPodArrayRecordNameSuffixes); + ArrayAttr resultAttrs = op.getAllResultAttrs(); + SplitFunctionNameInfo resultNameInfo = + collectSplitFunctionNameInfo(op.getResultTypes(), [resultAttrs](unsigned i) { + return getAttrAtIndexWithName(resultAttrs, i, RES_NAME_ATTR_NAME); + }, getSplitPodArrayRecordNameSuffixes); + + rewriter.modifyOpInPlace(op, [&]() { + op.setFunctionType(FunctionType::get(op.getContext(), newInputs, newResults)); + if (ArrayAttr newArgAttrs = replicateFunctionNameAttrsAsNeeded( + op.getArgAttrsAttr(), originalInputIdxToSize, newInputs, ARG_NAME_ATTR_NAME, + inputNameInfo.originalNames, inputNameInfo.existingNames, + inputNameInfo.splitNameSuffixes + )) { + op.setArgAttrsAttr(newArgAttrs); + } + if (ArrayAttr newResAttrs = replicateFunctionNameAttrsAsNeeded( + op.getResAttrsAttr(), originalResultIdxToSize, newResults, RES_NAME_ATTR_NAME, + resultNameInfo.originalNames, resultNameInfo.existingNames, + resultNameInfo.splitNameSuffixes + )) { + op.setResAttrsAttr(newResAttrs); + } + }); + return success(); + } +}; - LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split pod-type members\n";); +/// Rewrite `function.return` to flatten any array-of-POD operands into their parallel arrays. +class SplitPodArrayInReturnOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ReturnOp op) { + return !containsSplittablePodArrayType(op.getOperands().getTypes()); + } + + LogicalResult matchAndRewrite( + ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + SmallVector newOperands = flattenConvertedValues(adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, ValueRange(newOperands)); + return success(); + } +}; + +/// Rewrite calls whose arguments or results contain arrays-of-POD to use the split signature. +class SplitPodArrayInCallOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(CallOp op) { + return !containsSplittablePodArrayType(op.getArgOperands().getTypes()) && + !containsSplittablePodArrayType(op.getResultTypes()); + } + + LogicalResult match(CallOp op) const override { return failure(legal(op)); } + + LogicalResult matchAndRewrite( + CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + auto *typeConverter = getTypeConverter(); + assert(typeConverter && "expected pod-array type converter"); + + SmallVector newResultTypes; + if (failed(typeConverter->convertTypes(op.getResultTypes(), newResultTypes))) { + return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod call results"); + } + + SmallVector> mapOperandStorage; + SmallVector mapOperands; + mapOperandStorage.reserve(adaptor.getMapOperands().size()); + mapOperands.reserve(adaptor.getMapOperands().size()); + for (ArrayRef mapOperandGroup : adaptor.getMapOperands()) { + mapOperandStorage.push_back(flattenConvertedValues(mapOperandGroup)); + } + for (const SmallVector &values : mapOperandStorage) { + mapOperands.push_back(values); + } + + SmallVector newArgOperands = flattenConvertedValues(adaptor.getArgOperands()); + CallOp newCall = createCallPreservingInstantiationOperands( + op.getLoc(), newResultTypes, op, mapOperands, newArgOperands, rewriter + ); + + SmallVector> replacementStorage; + replacementStorage.reserve(op.getNumResults()); + auto newResultIt = newCall.getResults().begin(); + for (Type oldResultType : op.getResultTypes()) { + SmallVector convertedTypes; + (void)splitPodArrayTypeTo(oldResultType, convertedTypes); + SmallVector replacementsForResult; + replacementsForResult.reserve(convertedTypes.size()); + for (size_t i = 0; i < convertedTypes.size(); ++i) { + replacementsForResult.push_back(*newResultIt); + ++newResultIt; + } + replacementStorage.push_back(std::move(replacementsForResult)); + } + + SmallVector replacements; + replacements.reserve(replacementStorage.size()); + for (const SmallVector &values : replacementStorage) { + replacements.push_back(values); + } + rewriter.replaceOpWithMultiple(op, replacements); + return success(); + } +}; + +/// Replace `array.length` on an array-of-POD with the equivalent length of any split leaf array. +class SplitPodArrayLengthOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ArrayLengthOp op) { return !splittablePodArray(op.getArrRefType()); } + + LogicalResult matchAndRewrite( + ArrayLengthOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + op, getSingleConvertedValue(adaptor.getArrRef()), getSingleConvertedValue(adaptor.getDim()) + ); + return success(); + } +}; + +/// Rewrite a write to a split array-of-POD struct member into writes to each parallel array member. +class SplitPodArrayInMemberWriteOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + +public: + SplitPodArrayInMemberWriteOp( + const TypeConverter &converter, MLIRContext *ctx, SymbolTableCollection &symTables, + const MemberReplacementMap &memberRepMap + ) + : OpConversionPattern(converter, ctx), tables(symTables), + repMapRef(memberRepMap) {} + + static bool legal(MemberWriteOp op) { return !splittablePodArray(op.getVal().getType()); } + + LogicalResult matchAndRewrite( + MemberWriteOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); + + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + ArrayType arrTy = llvm::cast(op.getVal().getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + for (auto [id, splitValRange] : llvm::zip_equal(splitIds, adaptor.getVal())) { + const MemberInfo &newMember = idToMember.at(id); + rewriter.create( + op.getLoc(), getSingleConvertedValue(adaptor.getComponent()), + FlatSymbolRefAttr::get(newMember.first), getSingleConvertedValue(splitValRange) + ); + } + rewriter.eraseOp(op); + return success(); + } +}; + +/// Rewrite a read from a split array-of-POD struct member into reads of each parallel array member. +class SplitPodArrayInMemberReadOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + +public: + SplitPodArrayInMemberReadOp( + const TypeConverter &converter, MLIRContext *ctx, SymbolTableCollection &symTables, + const MemberReplacementMap &memberRepMap + ) + : OpConversionPattern(converter, ctx), tables(symTables), + repMapRef(memberRepMap) {} + + static bool legal(MemberReadOp op) { return !splittablePodArray(op.getResult().getType()); } + + LogicalResult matchAndRewrite( + MemberReadOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); + + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + ArrayType arrTy = llvm::cast(op.getType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector replacements; + replacements.reserve(splitIds.size()); + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + const MemberInfo &newMember = idToMember.at(id); + replacements.push_back(rewriter.create( + op.getLoc(), splitType, getSingleConvertedValue(adaptor.getComponent()), newMember.first + )); + } + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } +}; + +/// Split arrays-of-POD into parallel arrays before direct pod scalarization. +static LogicalResult +step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { + MLIRContext *ctx = modOp.getContext(); + PodArrayTypeConverter typeConverter; + + RewritePatternSet patterns(ctx); + patterns.add< + SplitPodArrayNonDetOp, SplitPodArrayCreateArrayOp, SplitPodArrayReadArrayOp, + SplitPodArrayWriteArrayOp, SplitPodArrayInFuncDefOp, SplitPodArrayInReturnOp, + SplitPodArrayInCallOp, SplitPodArrayLengthOp>(typeConverter, ctx); + patterns.add( + typeConverter, ctx, symTables, memberRepMap + ); + + ConversionTarget target(*ctx); + baseTargetSetup(target); + target.addDynamicallyLegalOp(SplitPodArrayNonDetOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayCreateArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayReadArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayWriteArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInFuncDefOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInReturnOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInCallOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayLengthOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInMemberWriteOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInMemberReadOp::legal); + + mlir::scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, target); + + LLVM_DEBUG(llvm::dbgs() << "Begin step 2: split arrays with POD element type\n";); return applyFullConversion(modOp, target, std::move(patterns)); } @@ -617,44 +1266,12 @@ class SplitPodInCallOp : public OpConversionPattern { } }; -/// Read a nested POD leaf by following each record name in `recordChain`. -static Value -genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &rewriter) { - Value value = podRef; - for (StringAttr attr : recordChain.nameList) { - value = genRead(loc, value, attr, rewriter); - } - return value; -} - /// State used while rebuilding a POD from flattened struct-member leaves. struct RebuildPodReadState { NewPodOp pod; DenseMap leafValues; }; -/// Reconstruct a POD record from the leaf values collected while splitting `struct.readm`. -static Value rebuildFlattenedPodRecord( - Location loc, Type recordType, SmallVectorImpl &recordChain, - const DenseMap &leafValues, ConversionPatternRewriter &rewriter -) { - if (PodType nestedPodTy = dyn_cast(recordType)) { - NewPodOp nestedPod = rewriter.create(loc, nestedPodTy); - for (RecordAttr record : nestedPodTy.getRecords()) { - recordChain.push_back(record.getName()); - Value recordValue = - rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, rewriter); - genWrite(loc, nestedPod, record.getName(), recordValue, rewriter); - recordChain.pop_back(); - } - return nestedPod; - } - - auto it = leafValues.find(RecordChain(recordChain)); - assert(it != leafValues.end() && "missing flattened POD leaf value"); - return it->second; -} - /// Rewrite a write to a pod-typed struct member into writes to the corresponding scalar leaves. class SplitPodInMemberWriteOp : public SplitAggregateInMemberRefOp< SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain> { @@ -726,7 +1343,7 @@ class SplitPodInMemberReadOp /// Special handling to split pods in struct member refs and function signatures and desugar /// initializations on pod.new into pod writes. static LogicalResult -step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { +step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { MLIRContext *ctx = modOp.getContext(); RewritePatternSet patterns(ctx); @@ -800,6 +1417,26 @@ static bool hasValueUse(Operation &op, Value value) { }); } +/// Return the nearest preceding same-record write that can be forwarded to `readOp`. +/// +/// This fold is intentionally conservative: it only forwards through intervening operations that do +/// not use the POD value at all. That keeps the rewrite local and avoids reasoning about other +/// whole-POD uses or record accesses that may observe mutation ordering. +static WritePodOp findNearestForwardableWriteInBlock(ReadPodOp readOp) { + Value podRef = readOp.getPodRef(); + StringAttr recordName = getRecordNameAsStringAttr(readOp); + + for (Operation *op = readOp->getPrevNode(); op; op = op->getPrevNode()) { + if (!hasValueUse(*op, podRef)) { + continue; + } + + auto writeOp = dyn_cast(op); + return writeOp && isSamePodRecord(writeOp, podRef, recordName) ? writeOp : nullptr; + } + return nullptr; +} + /// Return whether the read is preceded by a write to the same pod record within its block. static bool hasEarlierWriteInBlock(ReadPodOp readOp) { Value podRef = readOp.getPodRef(); @@ -876,6 +1513,20 @@ static WritePodOp findPrecedingWriteForIfRead(ReadPodOp readOp) { return replacement; } +/// Replace a read with the value from the nearest preceding same-record write in the block. +class FoldReadAfterWriteInBlockPattern final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReadPodOp readOp, PatternRewriter &rewriter) const override { + if (WritePodOp writeOp = findNearestForwardableWriteInBlock(readOp)) { + rewriter.replaceOp(readOp, writeOp.getValue()); + return success(); + } + return failure(); + } +}; + /// Replace a branch-local read with a value available in the parent block. class ReplaceIfReadPattern final : public OpRewritePattern { public: @@ -1559,13 +2210,12 @@ applyGreedily(ModuleOp modOp, RewritePatternSet &&patterns, bool *changed = null /// Repeatedly lift pod accesses out of supported SCF regions so SROA + mem2reg can eliminate the /// remaining POD storage. -static LogicalResult step3(ModuleOp modOp) { +static LogicalResult step4(ModuleOp modOp) { RewritePatternSet patterns(modOp.getContext()); patterns.add< - ReplaceIfReadPattern, LiftPodWritesFromIfBlocksPattern, LiftPodAccessesFromForLoopPattern, - LiftPodAccessesFromWhileLoopPattern, FoldIfCarriedPodReadAfterWritePattern>( - patterns.getContext() - ); + FoldReadAfterWriteInBlockPattern, ReplaceIfReadPattern, LiftPodWritesFromIfBlocksPattern, + LiftPodAccessesFromForLoopPattern, LiftPodAccessesFromWhileLoopPattern, + FoldIfCarriedPodReadAfterWritePattern>(patterns.getContext()); LLVM_DEBUG(llvm::dbgs() << "Begin step 3: refactor pod ops within SCF regions\n";); return applyGreedily(modOp, std::move(patterns)); @@ -1648,13 +2298,21 @@ class PassImpl : public llzk::pod::impl::PodToScalarPassBase { llvm::dbgs() << "After step 2:\n"; module.dump(); }); + + if (failed(step3(module, symTables, memberRepMap))) { + return signalPassFailure(); + } + LLVM_DEBUG({ + llvm::dbgs() << "After step 3:\n"; + module.dump(); + }); } - if (failed(step3(module))) { + if (failed(step4(module))) { return signalPassFailure(); } LLVM_DEBUG({ - llvm::dbgs() << "After step 3:\n"; + llvm::dbgs() << "After step 4:\n"; module.dump(); }); diff --git a/lib/Transforms/LLZKTransformationPassPipelines.cpp b/lib/Transforms/LLZKTransformationPassPipelines.cpp index 1c87cfcb6..1de684a55 100644 --- a/lib/Transforms/LLZKTransformationPassPipelines.cpp +++ b/lib/Transforms/LLZKTransformationPassPipelines.cpp @@ -49,8 +49,8 @@ void buildFullStructInliningPipelineImpl( } pm.addPass(polymorphic::createFlatteningPass(flattening)); - // Run array-to-scalar first because it can split arrays within a pod - // but pod-to-scalar cannot split pods within an array. + // Run array-to-scalar first because it can still scalarize plain array structure nested inside + // pods before pod-to-scalar rewrites arrays whose element type is a POD into parallel arrays. if (arrayToScalar) { pm.addPass(array::createArrayToScalarPass()); } diff --git a/test/Transforms/PodToScalar/circom_decomp_prod.llzk b/test/Transforms/PodToScalar/circom_decomp_prod.llzk index 9c2d0a87c..7a1e01dd2 100644 --- a/test/Transforms/PodToScalar/circom_decomp_prod.llzk +++ b/test/Transforms/PodToScalar/circom_decomp_prod.llzk @@ -288,61 +288,55 @@ module attributes {llzk.lang = "circom", llzk.main = !struct.type<@DecomposeProd // CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Num2Bits_0> // CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> // CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> -// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> -// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_7:[0-9a-zA-Z_\.]+]] = %[[VAL_4]], %[[VAL_8:[0-9a-zA-Z_\.]+]] = %[[VAL_3]], %[[VAL_9:[0-9a-zA-Z_\.]+]] = %[[VAL_5]]) : (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) -> (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) { -// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> -// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_9]], %[[VAL_10]]) : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.condition(%[[VAL_11]]) %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = scf.while (%[[VAL_5:[0-9a-zA-Z_\.]+]] = %[[VAL_3]]) : (!felt.type<"bn128">) -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_5]], %[[VAL_6]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_7]]) %[[VAL_5]] : !felt.type<"bn128"> // CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[VAL_12:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_13:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_14:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): -// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = felt.shr %[[VAL_0]], %[[VAL_14]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = felt.bit_and %[[VAL_15]], %[[VAL_16]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_14]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_2]]{{\[}}%[[VAL_18]]] = %[[VAL_17]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_19:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_14]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_2]]{{\[}}%[[VAL_19]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_20]], %[[VAL_4]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_13]], %[[VAL_21]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_12]], %[[VAL_12]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_14]], %[[VAL_24]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.yield %[[VAL_23]], %[[VAL_22]], %[[VAL_25]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: ^bb0(%[[VAL_8:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.shr %[[VAL_0]], %[[VAL_8]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = felt.bit_and %[[VAL_9]], %[[VAL_10]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_8]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_2]]{{\[}}%[[VAL_12]]] = %[[VAL_11]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_8]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_2]]{{\[}}%[[VAL_13]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_8]], %[[VAL_15]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_16]] : !felt.type<"bn128"> // CHECK-NEXT: } // CHECK-NEXT: struct.writem %[[VAL_1]][@out] = %[[VAL_2]] : <@Num2Bits_0>, !array.type<8 x !felt.type<"bn128">> // CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Num2Bits_0> // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[VAL_26:[0-9a-zA-Z_\.]+]]: !struct.type<@Num2Bits_0>, %[[VAL_27:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) attributes {function.allow_constraint, function.allow_non_native_field_ops} { -// CHECK-NEXT: %[[VAL_28:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_26]][@out] : <@Num2Bits_0>, !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: %[[VAL_29:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> -// CHECK-NEXT: %[[VAL_30:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_31:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> -// CHECK-NEXT: %[[VAL_32:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_33:[0-9a-zA-Z_\.]+]] = %[[VAL_29]], %[[VAL_34:[0-9a-zA-Z_\.]+]] = %[[VAL_30]], %[[VAL_35:[0-9a-zA-Z_\.]+]] = %[[VAL_31]]) : (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) -> (!felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128">) { -// CHECK-NEXT: %[[VAL_36:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> -// CHECK-NEXT: %[[VAL_37:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_35]], %[[VAL_36]]) : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.condition(%[[VAL_37]]) %[[VAL_33]], %[[VAL_34]], %[[VAL_35]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: function.def @constrain(%[[VAL_17:[0-9a-zA-Z_\.]+]]: !struct.type<@Num2Bits_0>, %[[VAL_18:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) attributes {function.allow_constraint, function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_19:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@out] : <@Num2Bits_0>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]]:2 = scf.while (%[[VAL_24:[0-9a-zA-Z_\.]+]] = %[[VAL_20]], %[[VAL_25:[0-9a-zA-Z_\.]+]] = %[[VAL_22]]) : (!felt.type<"bn128">, !felt.type<"bn128">) -> (!felt.type<"bn128">, !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_26:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_27:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_25]], %[[VAL_26]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_27]]) %[[VAL_24]], %[[VAL_25]] : !felt.type<"bn128">, !felt.type<"bn128"> // CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[VAL_38:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_39:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_40:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): -// CHECK-NEXT: %[[VAL_41:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_40]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_42:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_28]]{{\[}}%[[VAL_41]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_43:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_40]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_44:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_28]]{{\[}}%[[VAL_43]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_45:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_46:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_44]], %[[VAL_45]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_47:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_42]], %[[VAL_46]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_48:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> -// CHECK-NEXT: constrain.eq %[[VAL_47]], %[[VAL_48]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_49:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_40]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_50:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_28]]{{\[}}%[[VAL_49]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_51:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_50]], %[[VAL_30]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_52:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_38]], %[[VAL_51]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_53:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_39]], %[[VAL_39]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_54:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_55:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_40]], %[[VAL_54]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.yield %[[VAL_52]], %[[VAL_53]], %[[VAL_55]] : !felt.type<"bn128">, !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: ^bb0(%[[VAL_28:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">, %[[VAL_29:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_30:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_29]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_31:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_30]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_32:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_29]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_33:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_32]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_34:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_35:[0-9a-zA-Z_\.]+]] = felt.sub %[[VAL_33]], %[[VAL_34]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_36:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_31]], %[[VAL_35]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_37:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_36]], %[[VAL_37]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_38:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_29]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_39:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_38]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_40:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_39]], %[[VAL_21]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_41:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_28]], %[[VAL_40]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_42:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_43:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_29]], %[[VAL_42]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_41]], %[[VAL_43]] : !felt.type<"bn128">, !felt.type<"bn128"> // CHECK-NEXT: } -// CHECK-NEXT: constrain.eq %[[VAL_32]]#0, %[[VAL_27]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_23]]#0, %[[VAL_18]] : !felt.type<"bn128">, !felt.type<"bn128"> // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } @@ -351,205 +345,199 @@ module attributes {llzk.lang = "circom", llzk.main = !struct.type<@DecomposeProd // CHECK-NEXT: struct.member @low : !array.type<8 x !felt.type<"bn128">> {llzk.pub} // CHECK-NEXT: struct.member @u16s : !array.type<8 x !felt.type<"bn128">> // CHECK-NEXT: struct.member @bits_low : !array.type<8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: struct.member @bits_low$inputs : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> +// CHECK-NEXT: struct.member @bits_low$inputs_in : !array.type<8 x !felt.type<"bn128">> // CHECK-NEXT: struct.member @bits_high : !array.type<8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: struct.member @bits_high$inputs : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> -// CHECK-NEXT: function.def @compute(%[[VAL_56:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_57:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) -> !struct.type<@DecomposeProduct_1> attributes {function.allow_non_native_field_ops, function.allow_witness} { -// CHECK-NEXT: %[[VAL_58:[0-9a-zA-Z_\.]+]] = struct.new : <@DecomposeProduct_1> -// CHECK-NEXT: %[[VAL_59:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: %[[VAL_60:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: %[[VAL_61:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: %[[VAL_62:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> -// CHECK-NEXT: %[[VAL_63:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index -// CHECK-NEXT: %[[VAL_64:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_65:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: scf.for %[[VAL_66:[0-9a-zA-Z_\.]+]] = %[[VAL_64]] to %[[VAL_63]] step %[[VAL_65]] { -// CHECK-NEXT: %[[VAL_67:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_62]]{{\[}}%[[VAL_66]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: %[[VAL_68:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: pod.write %[[VAL_67]][@count] = %[[VAL_68]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index -// CHECK-NEXT: array.write %[[VAL_62]]{{\[}}%[[VAL_66]]] = %[[VAL_67]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: struct.member @bits_high$inputs_in : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: function.def @compute(%[[VAL_44:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_45:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) -> !struct.type<@DecomposeProduct_1> attributes {function.allow_non_native_field_ops, function.allow_witness} { +// CHECK-NEXT: %[[VAL_46:[0-9a-zA-Z_\.]+]] = struct.new : <@DecomposeProduct_1> +// CHECK-NEXT: %[[VAL_47:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_48:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_49:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_50:[0-9a-zA-Z_\.]+]] = array.new : <8 x index> +// CHECK-NEXT: %[[VAL_51:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_52:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_53:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_54:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_55:[0-9a-zA-Z_\.]+]] = %[[VAL_53]] to %[[VAL_52]] step %[[VAL_54]] { +// CHECK-NEXT: %[[VAL_56:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_50]]{{\[}}%[[VAL_55]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_57:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_51]]{{\[}}%[[VAL_55]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_58:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[VAL_50]]{{\[}}%[[VAL_55]]] = %[[VAL_58]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_51]]{{\[}}%[[VAL_55]]] = %[[VAL_57]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> // CHECK-NEXT: } -// CHECK-NEXT: %[[VAL_69:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@in: !felt.type<"bn128">]>> -// CHECK-NEXT: %[[VAL_70:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>> -// CHECK-NEXT: %[[VAL_71:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index -// CHECK-NEXT: %[[VAL_72:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_73:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: scf.for %[[VAL_74:[0-9a-zA-Z_\.]+]] = %[[VAL_72]] to %[[VAL_71]] step %[[VAL_73]] { -// CHECK-NEXT: %[[VAL_75:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_70]]{{\[}}%[[VAL_74]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: %[[VAL_76:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: pod.write %[[VAL_75]][@count] = %[[VAL_76]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index -// CHECK-NEXT: array.write %[[VAL_70]]{{\[}}%[[VAL_74]]] = %[[VAL_75]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> +// CHECK-NEXT: %[[VAL_59:[0-9a-zA-Z_\.]+]] = array.new : <8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_60:[0-9a-zA-Z_\.]+]] = array.new : <8 x index> +// CHECK-NEXT: %[[VAL_61:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_62:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_63:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_64:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_65:[0-9a-zA-Z_\.]+]] = %[[VAL_63]] to %[[VAL_62]] step %[[VAL_64]] { +// CHECK-NEXT: %[[VAL_66:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_65]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_67:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_65]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_68:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[VAL_60]]{{\[}}%[[VAL_65]]] = %[[VAL_68]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_61]]{{\[}}%[[VAL_65]]] = %[[VAL_67]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> // CHECK-NEXT: } -// CHECK-NEXT: %[[VAL_77:[0-9a-zA-Z_\.]+]] = array.new : <8 x !pod.type<[@in: !felt.type<"bn128">]>> -// CHECK-NEXT: %[[VAL_78:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> -// CHECK-NEXT: %[[VAL_79:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_80:[0-9a-zA-Z_\.]+]] = %[[VAL_77]], %[[VAL_81:[0-9a-zA-Z_\.]+]] = %[[VAL_69]], %[[VAL_82:[0-9a-zA-Z_\.]+]] = %[[VAL_78]]) : (!array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128">) -> (!array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128">) { -// CHECK-NEXT: %[[VAL_83:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> -// CHECK-NEXT: %[[VAL_84:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_82]], %[[VAL_83]]) : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_69:[0-9a-zA-Z_\.]+]] = array.new : <8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_70:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_71:[0-9a-zA-Z_\.]+]]:3 = scf.while (%[[VAL_72:[0-9a-zA-Z_\.]+]] = %[[VAL_69]], %[[VAL_73:[0-9a-zA-Z_\.]+]] = %[[VAL_59]], %[[VAL_74:[0-9a-zA-Z_\.]+]] = %[[VAL_70]]) : (!array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128">) -> (!array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_75:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_76:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_74]], %[[VAL_75]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_76]]) %[[VAL_72]], %[[VAL_73]], %[[VAL_74]] : !array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128"> // CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[VAL_85:[0-9a-zA-Z_\.]+]]: !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, %[[VAL_86:[0-9a-zA-Z_\.]+]]: !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, %[[VAL_87:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): -// CHECK-NEXT: %[[VAL_88:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_89:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_56]]{{\[}}%[[VAL_88]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_90:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_91:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_57]]{{\[}}%[[VAL_90]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_92:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_89]], %[[VAL_91]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_93:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_59]]{{\[}}%[[VAL_93]]] = %[[VAL_92]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_94:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_95:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_59]]{{\[}}%[[VAL_94]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_96:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> -// CHECK-NEXT: %[[VAL_97:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_95]], %[[VAL_96]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_98:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_60]]{{\[}}%[[VAL_98]]] = %[[VAL_97]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_99:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_100:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_99]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_101:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_102:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_86]]{{\[}}%[[VAL_101]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: pod.write %[[VAL_102]][@in] = %[[VAL_100]] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_103:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_86]]{{\[}}%[[VAL_103]]] = %[[VAL_102]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: %[[VAL_104:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_105:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_62]]{{\[}}%[[VAL_104]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: %[[VAL_106:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_105]][@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index -// CHECK-NEXT: %[[VAL_107:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: %[[VAL_108:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_106]], %[[VAL_107]] : index -// CHECK-NEXT: pod.write %[[VAL_105]][@count] = %[[VAL_108]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index -// CHECK-NEXT: %[[VAL_109:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_110:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_108]], %[[VAL_109]] : index -// CHECK-NEXT: scf.if %[[VAL_110]] { -// CHECK-NEXT: %[[VAL_111:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_100]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> -// CHECK-NEXT: pod.write %[[VAL_105]][@comp] = %[[VAL_111]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> -// CHECK-NEXT: %[[VAL_112:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_62]]{{\[}}%[[VAL_112]]] = %[[VAL_105]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: } else { +// CHECK-NEXT: ^bb0(%[[VAL_77:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_78:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_79:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_80:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_81:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_44]]{{\[}}%[[VAL_80]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_82:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_83:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_45]]{{\[}}%[[VAL_82]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_84:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_81]], %[[VAL_83]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_85:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_47]]{{\[}}%[[VAL_85]]] = %[[VAL_84]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_86:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_87:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_47]]{{\[}}%[[VAL_86]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_88:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_89:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_87]], %[[VAL_88]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_90:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_48]]{{\[}}%[[VAL_90]]] = %[[VAL_89]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_91:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_92:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_48]]{{\[}}%[[VAL_91]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_93:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_94:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_78]]{{\[}}%[[VAL_93]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_95:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_78]]{{\[}}%[[VAL_95]]] = %[[VAL_92]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_96:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_97:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_50]]{{\[}}%[[VAL_96]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_98:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_51]]{{\[}}%[[VAL_96]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_99:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_100:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_97]], %[[VAL_99]] : index +// CHECK-NEXT: %[[VAL_101:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_102:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_100]], %[[VAL_101]] : index +// CHECK-NEXT: scf.if %[[VAL_102]] { +// CHECK-NEXT: %[[VAL_103:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_92]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_104:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_50]]{{\[}}%[[VAL_104]]] = %[[VAL_100]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_51]]{{\[}}%[[VAL_104]]] = %[[VAL_103]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> // CHECK-NEXT: } -// CHECK-NEXT: %[[VAL_113:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_114:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_59]]{{\[}}%[[VAL_113]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_115:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> -// CHECK-NEXT: %[[VAL_116:[0-9a-zA-Z_\.]+]] = felt.uintdiv %[[VAL_114]], %[[VAL_115]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_117:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> -// CHECK-NEXT: %[[VAL_118:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_116]], %[[VAL_117]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_119:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_61]]{{\[}}%[[VAL_119]]] = %[[VAL_118]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_120:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_121:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_120]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_122:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_123:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_85]]{{\[}}%[[VAL_122]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: pod.write %[[VAL_123]][@in] = %[[VAL_121]] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_124:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_85]]{{\[}}%[[VAL_124]]] = %[[VAL_123]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: %[[VAL_125:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_126:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_70]]{{\[}}%[[VAL_125]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: %[[VAL_127:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_126]][@count] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index -// CHECK-NEXT: %[[VAL_128:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: %[[VAL_129:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_127]], %[[VAL_128]] : index -// CHECK-NEXT: pod.write %[[VAL_126]][@count] = %[[VAL_129]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, index -// CHECK-NEXT: %[[VAL_130:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_131:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_129]], %[[VAL_130]] : index -// CHECK-NEXT: scf.if %[[VAL_131]] { -// CHECK-NEXT: %[[VAL_132:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_121]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> -// CHECK-NEXT: pod.write %[[VAL_126]][@comp] = %[[VAL_132]] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> -// CHECK-NEXT: %[[VAL_133:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_87]] : !felt.type<"bn128"> -// CHECK-NEXT: array.write %[[VAL_70]]{{\[}}%[[VAL_133]]] = %[[VAL_126]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: } else { +// CHECK-NEXT: %[[VAL_105:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_106:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_47]]{{\[}}%[[VAL_105]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_107:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_108:[0-9a-zA-Z_\.]+]] = felt.uintdiv %[[VAL_106]], %[[VAL_107]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_109:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_110:[0-9a-zA-Z_\.]+]] = felt.umod %[[VAL_108]], %[[VAL_109]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_111:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_49]]{{\[}}%[[VAL_111]]] = %[[VAL_110]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_112:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_113:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_49]]{{\[}}%[[VAL_112]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_114:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_115:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_77]]{{\[}}%[[VAL_114]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_116:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_77]]{{\[}}%[[VAL_116]]] = %[[VAL_113]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_117:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_118:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_117]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_119:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_117]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_120:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_121:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_118]], %[[VAL_120]] : index +// CHECK-NEXT: %[[VAL_122:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_123:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_121]], %[[VAL_122]] : index +// CHECK-NEXT: scf.if %[[VAL_123]] { +// CHECK-NEXT: %[[VAL_124:[0-9a-zA-Z_\.]+]] = function.call @Num2Bits_0::@compute(%[[VAL_113]]) : (!felt.type<"bn128">) -> !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_125:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_79]] : !felt.type<"bn128"> +// CHECK-NEXT: array.write %[[VAL_60]]{{\[}}%[[VAL_125]]] = %[[VAL_121]] : <8 x index>, index +// CHECK-NEXT: array.write %[[VAL_61]]{{\[}}%[[VAL_125]]] = %[[VAL_124]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> // CHECK-NEXT: } -// CHECK-NEXT: %[[VAL_134:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_135:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_87]], %[[VAL_134]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.yield %[[VAL_85]], %[[VAL_86]], %[[VAL_135]] : !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_126:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_127:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_79]], %[[VAL_126]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_77]], %[[VAL_78]], %[[VAL_127]] : !array.type<8 x !felt.type<"bn128">>, !array.type<8 x !felt.type<"bn128">>, !felt.type<"bn128"> // CHECK-NEXT: } -// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_high$inputs] = %[[VAL_79]]#0 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> -// CHECK-NEXT: %[[VAL_136:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: %[[VAL_137:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index -// CHECK-NEXT: %[[VAL_138:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_139:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: scf.for %[[VAL_140:[0-9a-zA-Z_\.]+]] = %[[VAL_138]] to %[[VAL_137]] step %[[VAL_139]] { -// CHECK-NEXT: %[[VAL_141:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_70]]{{\[}}%[[VAL_140]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: %[[VAL_142:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_141]][@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> -// CHECK-NEXT: array.write %[[VAL_136]]{{\[}}%[[VAL_140]]] = %[[VAL_142]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_high$inputs_in] = %[[VAL_71]]#0 : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_128:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_129:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_130:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_131:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_132:[0-9a-zA-Z_\.]+]] = %[[VAL_130]] to %[[VAL_129]] step %[[VAL_131]] { +// CHECK-NEXT: %[[VAL_133:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_60]]{{\[}}%[[VAL_132]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_134:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_61]]{{\[}}%[[VAL_132]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: array.write %[[VAL_128]]{{\[}}%[[VAL_132]]] = %[[VAL_134]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> // CHECK-NEXT: } -// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_high] = %[[VAL_136]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_low$inputs] = %[[VAL_79]]#1 : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> -// CHECK-NEXT: %[[VAL_143:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: %[[VAL_144:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index -// CHECK-NEXT: %[[VAL_145:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_146:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: scf.for %[[VAL_147:[0-9a-zA-Z_\.]+]] = %[[VAL_145]] to %[[VAL_144]] step %[[VAL_146]] { -// CHECK-NEXT: %[[VAL_148:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_62]]{{\[}}%[[VAL_147]]] : <8 x !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>>, !pod.type<[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]> -// CHECK-NEXT: %[[VAL_149:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_148]][@comp] : <[@count: index, @comp: !struct.type<@Num2Bits_0>, @params: !pod.type<[]>]>, !struct.type<@Num2Bits_0> -// CHECK-NEXT: array.write %[[VAL_143]]{{\[}}%[[VAL_147]]] = %[[VAL_149]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_high] = %[[VAL_128]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_low$inputs_in] = %[[VAL_71]]#1 : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_135:[0-9a-zA-Z_\.]+]] = array.new : <8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_136:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_137:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_138:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_139:[0-9a-zA-Z_\.]+]] = %[[VAL_137]] to %[[VAL_136]] step %[[VAL_138]] { +// CHECK-NEXT: %[[VAL_140:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_50]]{{\[}}%[[VAL_139]]] : <8 x index>, index +// CHECK-NEXT: %[[VAL_141:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_51]]{{\[}}%[[VAL_139]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: array.write %[[VAL_135]]{{\[}}%[[VAL_139]]] = %[[VAL_141]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> // CHECK-NEXT: } -// CHECK-NEXT: struct.writem %[[VAL_58]][@bits_low] = %[[VAL_143]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: struct.writem %[[VAL_58]][@u16s] = %[[VAL_59]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: struct.writem %[[VAL_58]][@low] = %[[VAL_60]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: struct.writem %[[VAL_58]][@high] = %[[VAL_61]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: function.return %[[VAL_58]] : !struct.type<@DecomposeProduct_1> +// CHECK-NEXT: struct.writem %[[VAL_46]][@bits_low] = %[[VAL_135]] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: struct.writem %[[VAL_46]][@u16s] = %[[VAL_47]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.writem %[[VAL_46]][@low] = %[[VAL_48]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: struct.writem %[[VAL_46]][@high] = %[[VAL_49]] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: function.return %[[VAL_46]] : !struct.type<@DecomposeProduct_1> // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[VAL_150:[0-9a-zA-Z_\.]+]]: !struct.type<@DecomposeProduct_1>, %[[VAL_151:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_152:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) attributes {function.allow_constraint, function.allow_non_native_field_ops} { -// CHECK-NEXT: %[[VAL_153:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@high] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: %[[VAL_154:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@low] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: %[[VAL_155:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@u16s] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> -// CHECK-NEXT: %[[VAL_156:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_low] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: %[[VAL_157:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_low$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> -// CHECK-NEXT: %[[VAL_158:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_high] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> -// CHECK-NEXT: %[[VAL_159:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_150]][@bits_high$inputs] : <@DecomposeProduct_1>, !array.type<8 x !pod.type<[@in: !felt.type<"bn128">]>> -// CHECK-NEXT: %[[VAL_160:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> -// CHECK-NEXT: %[[VAL_161:[0-9a-zA-Z_\.]+]] = scf.while (%[[VAL_162:[0-9a-zA-Z_\.]+]] = %[[VAL_160]]) : (!felt.type<"bn128">) -> !felt.type<"bn128"> { -// CHECK-NEXT: %[[VAL_163:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> -// CHECK-NEXT: %[[VAL_164:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_162]], %[[VAL_163]]) : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.condition(%[[VAL_164]]) %[[VAL_162]] : !felt.type<"bn128"> +// CHECK-NEXT: function.def @constrain(%[[VAL_142:[0-9a-zA-Z_\.]+]]: !struct.type<@DecomposeProduct_1>, %[[VAL_143:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>, %[[VAL_144:[0-9a-zA-Z_\.]+]]: !array.type<8 x !felt.type<"bn128">>) attributes {function.allow_constraint, function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_145:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@high] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_146:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@low] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_147:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@u16s] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_148:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_low] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_149:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_low$inputs_in] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_150:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_high] : <@DecomposeProduct_1>, !array.type<8 x !struct.type<@Num2Bits_0>> +// CHECK-NEXT: %[[VAL_151:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_142]][@bits_high$inputs_in] : <@DecomposeProduct_1>, !array.type<8 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_152:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_153:[0-9a-zA-Z_\.]+]] = scf.while (%[[VAL_154:[0-9a-zA-Z_\.]+]] = %[[VAL_152]]) : (!felt.type<"bn128">) -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_155:[0-9a-zA-Z_\.]+]] = felt.const 8 : <"bn128"> +// CHECK-NEXT: %[[VAL_156:[0-9a-zA-Z_\.]+]] = bool.cmp lt(%[[VAL_154]], %[[VAL_155]]) : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.condition(%[[VAL_156]]) %[[VAL_154]] : !felt.type<"bn128"> // CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[VAL_165:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): -// CHECK-NEXT: %[[VAL_166:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_167:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_151]]{{\[}}%[[VAL_166]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_168:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_169:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_152]]{{\[}}%[[VAL_168]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_170:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_167]], %[[VAL_169]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_171:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_172:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_155]]{{\[}}%[[VAL_171]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: ^bb0(%[[VAL_157:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">): +// CHECK-NEXT: %[[VAL_158:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_159:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_143]]{{\[}}%[[VAL_158]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_160:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_161:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_144]]{{\[}}%[[VAL_160]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_162:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_159]], %[[VAL_161]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_163:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_164:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_147]]{{\[}}%[[VAL_163]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_164]], %[[VAL_162]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_165:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_166:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_146]]{{\[}}%[[VAL_165]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_167:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_168:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_149]]{{\[}}%[[VAL_167]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_168]], %[[VAL_166]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_169:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_170:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_145]]{{\[}}%[[VAL_169]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_171:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_172:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_151]]{{\[}}%[[VAL_171]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> // CHECK-NEXT: constrain.eq %[[VAL_172]], %[[VAL_170]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_173:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_174:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_154]]{{\[}}%[[VAL_173]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_175:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_176:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_157]]{{\[}}%[[VAL_175]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: %[[VAL_177:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_176]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> -// CHECK-NEXT: constrain.eq %[[VAL_177]], %[[VAL_174]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_178:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_179:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_153]]{{\[}}%[[VAL_178]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_180:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_181:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_159]]{{\[}}%[[VAL_180]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: %[[VAL_182:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_181]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> -// CHECK-NEXT: constrain.eq %[[VAL_182]], %[[VAL_179]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_183:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_184:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_155]]{{\[}}%[[VAL_183]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_185:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_186:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_154]]{{\[}}%[[VAL_185]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_187:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> -// CHECK-NEXT: %[[VAL_188:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_165]] : !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_189:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_153]]{{\[}}%[[VAL_188]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_190:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_187]], %[[VAL_189]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_191:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_186]], %[[VAL_190]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: constrain.eq %[[VAL_184]], %[[VAL_191]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: %[[VAL_192:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> -// CHECK-NEXT: %[[VAL_193:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_165]], %[[VAL_192]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: scf.yield %[[VAL_193]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_173:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_174:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_147]]{{\[}}%[[VAL_173]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_175:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_176:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_146]]{{\[}}%[[VAL_175]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_177:[0-9a-zA-Z_\.]+]] = felt.const 256 : <"bn128"> +// CHECK-NEXT: %[[VAL_178:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_157]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_179:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_145]]{{\[}}%[[VAL_178]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_180:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_177]], %[[VAL_179]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_181:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_176]], %[[VAL_180]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_174]], %[[VAL_181]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_182:[0-9a-zA-Z_\.]+]] = felt.const 1 : <"bn128"> +// CHECK-NEXT: %[[VAL_183:[0-9a-zA-Z_\.]+]] = felt.add %[[VAL_157]], %[[VAL_182]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: scf.yield %[[VAL_183]] : !felt.type<"bn128"> // CHECK-NEXT: } -// CHECK-NEXT: %[[VAL_194:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index -// CHECK-NEXT: %[[VAL_195:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_196:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: scf.for %[[VAL_197:[0-9a-zA-Z_\.]+]] = %[[VAL_195]] to %[[VAL_194]] step %[[VAL_196]] { -// CHECK-NEXT: %[[VAL_198:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_158]]{{\[}}%[[VAL_197]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> -// CHECK-NEXT: %[[VAL_199:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_159]]{{\[}}%[[VAL_197]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: %[[VAL_200:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_199]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> -// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_198]], %[[VAL_200]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () +// CHECK-NEXT: %[[VAL_184:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_185:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_186:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_187:[0-9a-zA-Z_\.]+]] = %[[VAL_185]] to %[[VAL_184]] step %[[VAL_186]] { +// CHECK-NEXT: %[[VAL_188:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_150]]{{\[}}%[[VAL_187]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_189:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_151]]{{\[}}%[[VAL_187]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_188]], %[[VAL_189]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () // CHECK-NEXT: } -// CHECK-NEXT: %[[VAL_201:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index -// CHECK-NEXT: %[[VAL_202:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[VAL_203:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: scf.for %[[VAL_204:[0-9a-zA-Z_\.]+]] = %[[VAL_202]] to %[[VAL_201]] step %[[VAL_203]] { -// CHECK-NEXT: %[[VAL_205:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_156]]{{\[}}%[[VAL_204]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> -// CHECK-NEXT: %[[VAL_206:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_157]]{{\[}}%[[VAL_204]]] : <8 x !pod.type<[@in: !felt.type<"bn128">]>>, !pod.type<[@in: !felt.type<"bn128">]> -// CHECK-NEXT: %[[VAL_207:[0-9a-zA-Z_\.]+]] = pod.read %[[VAL_206]][@in] : <[@in: !felt.type<"bn128">]>, !felt.type<"bn128"> -// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_205]], %[[VAL_207]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () +// CHECK-NEXT: %[[VAL_190:[0-9a-zA-Z_\.]+]] = arith.constant 8 : index +// CHECK-NEXT: %[[VAL_191:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_192:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[VAL_193:[0-9a-zA-Z_\.]+]] = %[[VAL_191]] to %[[VAL_190]] step %[[VAL_192]] { +// CHECK-NEXT: %[[VAL_194:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_148]]{{\[}}%[[VAL_193]]] : <8 x !struct.type<@Num2Bits_0>>, !struct.type<@Num2Bits_0> +// CHECK-NEXT: %[[VAL_195:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_149]]{{\[}}%[[VAL_193]]] : <8 x !felt.type<"bn128">>, !felt.type<"bn128"> +// CHECK-NEXT: function.call @Num2Bits_0::@constrain(%[[VAL_194]], %[[VAL_195]]) : (!struct.type<@Num2Bits_0>, !felt.type<"bn128">) -> () // CHECK-NEXT: } // CHECK-NEXT: function.return // CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/function_calls_with_pod.llzk b/test/Transforms/PodToScalar/function_calls_with_pod.llzk index 6a24ab974..549f771ff 100644 --- a/test/Transforms/PodToScalar/function_calls_with_pod.llzk +++ b/test/Transforms/PodToScalar/function_calls_with_pod.llzk @@ -159,11 +159,11 @@ module attributes {llzk.lang} { } } // CHECK-LABEL: module attributes {llzk.lang} { -// CHECK-NEXT: function.def @id_array_of_pod(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x !pod.type<[@x: index]>>) -> !array.type<2 x !pod.type<[@x: index]>> { -// CHECK-NEXT: function.return %[[VAL_0]] : !array.type<2 x !pod.type<[@x: index]>> +// CHECK-NEXT: function.def @id_array_of_pod(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>) -> !array.type<2 x index> { +// CHECK-NEXT: function.return %[[VAL_0]] : !array.type<2 x index> // CHECK-NEXT: } -// CHECK-NEXT: function.def @main(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x !pod.type<[@x: index]>>) -> !array.type<2 x !pod.type<[@x: index]>> { -// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @id_array_of_pod(%[[VAL_1]]) : (!array.type<2 x !pod.type<[@x: index]>>) -> !array.type<2 x !pod.type<[@x: index]>> -// CHECK-NEXT: function.return %[[VAL_2]] : !array.type<2 x !pod.type<[@x: index]>> +// CHECK-NEXT: function.def @main(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>) -> !array.type<2 x index> { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @id_array_of_pod(%[[VAL_1]]) : (!array.type<2 x index>) -> !array.type<2 x index> +// CHECK-NEXT: function.return %[[VAL_2]] : !array.type<2 x index> // CHECK-NEXT: } // CHECK-NEXT: } From 6793d6ed4e2180f0e37c00bf370ae7fb34c56342 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 18 Jun 2026 18:01:40 -0500 Subject: [PATCH 04/36] fix compile error from merge --- lib/Dialect/POD/Transforms/PodToScalarPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 1828ade0b..c8084bb6d 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -1414,7 +1414,7 @@ static bool hasValueUse(Operation &op, Value value) { /// whole-POD uses or record accesses that may observe mutation ordering. static WritePodOp findNearestForwardableWriteInBlock(ReadPodOp readOp) { Value podRef = readOp.getPodRef(); - StringAttr recordName = getRecordNameAsStringAttr(readOp); + StringAttr recordName = readOp.getRecordNameAttr(); for (Operation *op = readOp->getPrevNode(); op; op = op->getPrevNode()) { if (!hasValueUse(*op, podRef)) { From 477c5322c07099f0ea2a149fa84f7c52da3e3afd Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 18 Jun 2026 18:04:01 -0500 Subject: [PATCH 05/36] fix shadowing warnings --- lib/Dialect/POD/Transforms/PodToScalarPass.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index c8084bb6d..09bf01177 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -784,22 +784,22 @@ class SplitPodArrayInFuncDefOp : public OpConversionPattern { LogicalResult matchAndRewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { - auto *typeConverter = getTypeConverter(); - assert(typeConverter && "expected pod-array type converter"); + auto *tyConv = getTypeConverter(); + assert(tyConv && "expected pod-array type converter"); FunctionType oldTy = op.getFunctionType(); TypeConverter::SignatureConversion inputConversion(oldTy.getNumInputs()); - if (failed(typeConverter->convertSignatureArgs(oldTy.getInputs(), inputConversion))) { + if (failed(tyConv->convertSignatureArgs(oldTy.getInputs(), inputConversion))) { return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod inputs"); } SmallVector newResults; - if (failed(typeConverter->convertTypes(oldTy.getResults(), newResults))) { + if (failed(tyConv->convertTypes(oldTy.getResults(), newResults))) { return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod results"); } if (!op.getBody().empty() && - failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter, &inputConversion))) { + failed(rewriter.convertRegionTypes(&op.getBody(), *tyConv, &inputConversion))) { return rewriter.notifyMatchFailure(op, "failed to convert function body block arguments"); } @@ -872,11 +872,11 @@ class SplitPodArrayInCallOp : public OpConversionPattern { LogicalResult matchAndRewrite( CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter ) const override { - auto *typeConverter = getTypeConverter(); - assert(typeConverter && "expected pod-array type converter"); + auto *tyConv = getTypeConverter(); + assert(tyConv && "expected pod-array type converter"); SmallVector newResultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), newResultTypes))) { + if (failed(tyConv->convertTypes(op.getResultTypes(), newResultTypes))) { return rewriter.notifyMatchFailure(op, "failed to convert array-of-pod call results"); } From 3d01279b0dbef30ceaba881f3772a45d191efb7b Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 18 Jun 2026 18:28:21 -0500 Subject: [PATCH 06/36] reverse pipeline ordering based on this update --- lib/Transforms/LLZKTransformationPassPipelines.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/Transforms/LLZKTransformationPassPipelines.cpp b/lib/Transforms/LLZKTransformationPassPipelines.cpp index 1de684a55..442bc9f69 100644 --- a/lib/Transforms/LLZKTransformationPassPipelines.cpp +++ b/lib/Transforms/LLZKTransformationPassPipelines.cpp @@ -49,14 +49,14 @@ void buildFullStructInliningPipelineImpl( } pm.addPass(polymorphic::createFlatteningPass(flattening)); - // Run array-to-scalar first because it can still scalarize plain array structure nested inside - // pods before pod-to-scalar rewrites arrays whose element type is a POD into parallel arrays. - if (arrayToScalar) { - pm.addPass(array::createArrayToScalarPass()); - } + // Run pod-to-scalar first because it is able to split `pod.type` used as array element type + // (into parallel arrays) so it should be able to fully remove all `pod.type` usages. if (podToScalar) { pm.addPass(pod::createPodToScalarPass()); } + if (arrayToScalar) { + pm.addPass(array::createArrayToScalarPass()); + } // Canonicalize to remove known-condition `scf.if` regions so struct inlining // can link "@compute" calls to struct members. pm.addPass(mlir::createCanonicalizerPass()); From e9a3bb257836d435aaaea6b714e27d85ac9ebb41 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Mon, 22 Jun 2026 08:59:36 -0500 Subject: [PATCH 07/36] clarify documentation --- .../POD/Transforms/TransformationPasses.td | 10 +++- .../POD/Transforms/PodToScalarPass.cpp | 57 ++++++++++--------- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/include/llzk/Dialect/POD/Transforms/TransformationPasses.td b/include/llzk/Dialect/POD/Transforms/TransformationPasses.td index 971d3bc93..65f93224d 100644 --- a/include/llzk/Dialect/POD/Transforms/TransformationPasses.td +++ b/include/llzk/Dialect/POD/Transforms/TransformationPasses.td @@ -15,7 +15,15 @@ include "llzk/Pass/PassBase.td" def PodToScalarPass : LLZKPass<"llzk-pod-to-scalar"> { let summary = "Replace PODs with scalar values"; let description = [{ - Replace `pod.type` values with the proper number of scalar values + Scalarize `pod.type` values by splitting POD-typed struct members into + multiple scalar members, splitting POD-typed array elements into parallel + arrays, then rewriting affected member accesses plus function signatures, + calls, and returns, and finally running POD-specific SROA + mem2reg cleanup + so the remaining POD storage is promoted to SSA values. + + If it is necessary to scalarize both PODs and arrays, run this pass before + running the `-llzk-array-to-scalar` pass because that pass will not scalarize + array types that are within a POD type. }]; } diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 09bf01177..4508ab4df 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -12,45 +12,50 @@ /// /// The steps of this transformation are as follows: /// -/// 0. Scan to find `llzk.nondet` ops that allocate uninitialized pods and replace them with -/// an equivalent `pod.new` +/// 0. Rewrite pod-typed `llzk.nondet` allocations into `pod.new` so later stages only need to +/// reason about POD storage through POD dialect operations. /// -/// 1. Run a dialect conversion that replaces direct `PodType` struct members with one scalar -/// member per record, replaces arrays whose element type is a POD with one parallel array per -/// scalar leaf record, and remembers how each original member was split. +/// 1. Run a dialect conversion that replaces pod-typed struct members with one scalar member per +/// POD record, replaces array-typed struct members whose element type is a POD with one parallel +/// array member per POD record, and remembers how each original member was split for the later +/// rewriting steps. /// -/// 2. Run a dialect conversion that does the following: +/// 2. Run a dialect conversion that splits arrays whose element type is a POD into parallel arrays +/// in `llzk.nondet`, `array.*`, `MemberReadOp`, `MemberWriteOp`, `FuncDefOp`, `CallOp`, and +/// `ReturnOp`. /// -/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the members that were split in step 1 -/// so they instead perform reads and writes on the new scalar or parallel-array members. The -/// transformation is local to the current op. Therefore, when replacing the `MemberReadOp` a -/// new pod is created locally and all uses of the `MemberReadOp` are replaced with the new -/// pod Value, then each scalar member read is followed by scalar write into the new pod. -/// Similarly, when replacing a `MemberWriteOp`, each element in the pod operand needs a -/// scalar read from the pod followed by a scalar write to the new member. Making only local -/// changes keeps this step simple and later steps will optimize. +/// 3. Run a dialect conversion that does the following: +/// +/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the pod-typed struct members split in +/// step 1 so they instead perform reads and writes on the new scalar members. This +/// transformation is local to the current op. Therefore, when replacing a `MemberReadOp`, a +/// new pod is created locally and all uses of the `MemberReadOp` are replaced with that new +/// POD value, then each scalar member read is followed by a scalar write into the new pod. +/// Similarly, when replacing a `MemberWriteOp`, each element in the pod operand needs a scalar +/// read from the pod followed by a scalar write to the new member. Making only local changes +/// keeps this step simple and lets later steps optimize away the temporary POD storage. /// /// - Remove optional initialization from `NewPodOp` and instead insert a list of `WritePodOp` /// immediately following. /// -/// - Split arrays whose element type is a POD into parallel arrays in `array.*`, -/// `FuncDefOp`, `CallOp`, and `ReturnOp`, then split remaining direct POD values to scalars -/// in `FuncDefOp`, `CallOp`, and `ReturnOp`. +/// - Split remaining direct POD values to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp`. +/// When a rewritten op still needs a POD-typed value locally, rebuild it with `pod.new` plus +/// `pod.write` so later cleanup can scalarize it away. /// -/// 3. Promote pod reads and writes out of `scf.if`, `scf.for`, and `scf.while` regions when the +/// 4. Promote pod reads and writes out of `scf.if`, `scf.for`, and `scf.while` regions when the /// access can be modeled as an SSA value flowing through the region boundary. This puts the /// pod accesses that mem2reg must eliminate into a parent block or loop-carried value. /// -/// 4. Run MLIR "sroa" pass to split each pod with `N` records into `N` pods with 1 record each +/// 5. Run MLIR "sroa" pass to split remaining POD allocations into single-record POD allocations /// (to prepare for the "mem2reg" pass because its API cannot split memory by itself). /// -/// 5. Run MLIR "mem2reg" pass to convert all single-record pod allocations and accesses into SSA +/// 6. Run MLIR "mem2reg" pass to convert all single-record POD allocations and accesses into SSA /// values. /// -/// 6. Remove pod allocations that become unread after memory promotion, then remove SSA values +/// 7. Remove POD allocations that become unread after memory promotion, then remove SSA values /// made dead by that cleanup. /// -/// ** Steps 4-6 are rerun while nested POD types are still being exposed, until a fixpoint. +/// Steps 5-7 are rerun while nested POD types are still being exposed, until a fixpoint. /// /// Note: This transformation imposes a "last write wins" semantics on pod records. If /// different/configurable semantics are added in the future, some additional transformation would @@ -235,12 +240,12 @@ static void forEachPodLeaf(PodType podTy, SmallVectorImpl &recordCha } } -/// If the given ArrayType has a POD element type, return it. +/// If the input ArrayType has a POD element type, return the input, else nullptr. inline static ArrayType splittablePodArray(ArrayType at) { return isa(at.getElementType()) ? at : nullptr; } -/// If the given Type is an ArrayType with a POD element type, return it. +/// If the input Type is an ArrayType with a POD element type, return the input, else nullptr. inline static ArrayType splittablePodArray(Type t) { if (ArrayType at = dyn_cast(t)) { return splittablePodArray(at); @@ -1372,7 +1377,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp(SplitPodInMemberWriteOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberReadOp::legal); - LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other pod ops\n";); + LLVM_DEBUG(llvm::dbgs() << "Begin step 3: update/split other pod ops\n";); return applyFullConversion(modOp, target, std::move(patterns)); } @@ -2205,7 +2210,7 @@ static LogicalResult step4(ModuleOp modOp) { LiftPodAccessesFromForLoopPattern, LiftPodAccessesFromWhileLoopPattern, FoldIfCarriedPodReadAfterWritePattern>(patterns.getContext()); - LLVM_DEBUG(llvm::dbgs() << "Begin step 3: refactor pod ops within SCF regions\n";); + LLVM_DEBUG(llvm::dbgs() << "Begin step 4: refactor pod ops within SCF regions\n";); return applyGreedily(modOp, std::move(patterns)); } From 6fb01256f40950aa00985b706a609a5d8e1a446c Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Mon, 22 Jun 2026 11:04:15 -0500 Subject: [PATCH 08/36] fix: Convert subarray ops for POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 68 ++++++++++++++++++- .../PodToScalar/array_extract_insert.llzk | 28 ++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 test/Transforms/PodToScalar/array_extract_insert.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 4508ab4df..144cf40b2 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -946,6 +946,67 @@ class SplitPodArrayLengthOp : public OpConversionPattern { } }; +/// Rewrite `array.extract` of an array-of-POD subarray into one extract per parallel leaf array. +class SplitPodArrayExtractArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(ExtractArrayOp op) { return !splittablePodArray(op.getResult().getType()); } + + LogicalResult matchAndRewrite( + ExtractArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + SmallVector splitResultTypes; + splitPodArrayTypeTo(op.getResult().getType(), splitResultTypes); + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + SmallVector replacements; + replacements.reserve(splitResultTypes.size()); + for (auto [splitArrRange, splitResultType] : + llvm::zip_equal(adaptor.getArrRef(), splitResultTypes)) { + replacements.push_back(rewriter.create( + op.getLoc(), llvm::cast(splitResultType), + getSingleConvertedValue(splitArrRange), indices + )); + } + + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } +}; + +/// Rewrite `array.insert` of an array-of-POD subarray into one insert per parallel leaf array. +class SplitPodArrayInsertArrayOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(InsertArrayOp op) { return !splittablePodArray(op.getRvalue().getType()); } + + LogicalResult matchAndRewrite( + InsertArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); + for (auto [splitArrRange, splitRvalueRange] : + llvm::zip_equal(adaptor.getArrRef(), adaptor.getRvalue())) { + rewriter.create( + op.getLoc(), getSingleConvertedValue(splitArrRange), indices, + getSingleConvertedValue(splitRvalueRange) + ); + } + + rewriter.eraseOp(op); + return success(); + } +}; + /// Rewrite a write to a split array-of-POD struct member into writes to each parallel array member. class SplitPodArrayInMemberWriteOp : public OpConversionPattern { SymbolTableCollection &tables; @@ -1044,8 +1105,9 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM RewritePatternSet patterns(ctx); patterns.add< SplitPodArrayNonDetOp, SplitPodArrayCreateArrayOp, SplitPodArrayReadArrayOp, - SplitPodArrayWriteArrayOp, SplitPodArrayInFuncDefOp, SplitPodArrayInReturnOp, - SplitPodArrayInCallOp, SplitPodArrayLengthOp>(typeConverter, ctx); + SplitPodArrayWriteArrayOp, SplitPodArrayExtractArrayOp, SplitPodArrayInsertArrayOp, + SplitPodArrayInFuncDefOp, SplitPodArrayInReturnOp, SplitPodArrayInCallOp, + SplitPodArrayLengthOp>(typeConverter, ctx); patterns.add( typeConverter, ctx, symTables, memberRepMap ); @@ -1056,6 +1118,8 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp(SplitPodArrayCreateArrayOp::legal); target.addDynamicallyLegalOp(SplitPodArrayReadArrayOp::legal); target.addDynamicallyLegalOp(SplitPodArrayWriteArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayExtractArrayOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInsertArrayOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInFuncDefOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInReturnOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInCallOp::legal); diff --git a/test/Transforms/PodToScalar/array_extract_insert.llzk b/test/Transforms/PodToScalar/array_extract_insert.llzk new file mode 100644 index 000000000..e4cf79245 --- /dev/null +++ b/test/Transforms/PodToScalar/array_extract_insert.llzk @@ -0,0 +1,28 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@lhs: index, @rhs: index]> +!PairRow = !array.type<2 x !Pair> +!PairMatrix = !array.type<2,2 x !Pair> +module attributes {llzk.lang} { + function.def @extract_then_insert(%src: !PairMatrix) -> !PairMatrix { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %row = array.extract %src[%c0] : !PairMatrix + %dst = array.new : !PairMatrix + array.insert %dst[%c1] = %row : !PairMatrix, !PairRow + function.return %dst : !PairMatrix + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @extract_then_insert(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>) -> (!array.type<2,2 x index>, !array.type<2,2 x index>) { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_0]]{{\[}}%[[VAL_2]]] : <2,2 x index> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_1]]{{\[}}%[[VAL_2]]] : <2,2 x index> +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = array.new : <2,2 x index> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.new : <2,2 x index> +// CHECK-NEXT: array.insert %[[VAL_6]]{{\[}}%[[VAL_3]]] = %[[VAL_4]] : <2,2 x index>, <2 x index> +// CHECK-NEXT: array.insert %[[VAL_7]]{{\[}}%[[VAL_3]]] = %[[VAL_5]] : <2,2 x index>, <2 x index> +// CHECK-NEXT: function.return %[[VAL_6]], %[[VAL_7]] : !array.type<2,2 x index>, !array.type<2,2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } From fd339e733c824e70fe2e080b7d921e6fd44ac9b5 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Mon, 22 Jun 2026 11:58:36 -0500 Subject: [PATCH 09/36] update test to ensure distinct types are preserved --- .../PodToScalar/array_extract_insert.llzk | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/Transforms/PodToScalar/array_extract_insert.llzk b/test/Transforms/PodToScalar/array_extract_insert.llzk index e4cf79245..227dc8ef8 100644 --- a/test/Transforms/PodToScalar/array_extract_insert.llzk +++ b/test/Transforms/PodToScalar/array_extract_insert.llzk @@ -1,6 +1,6 @@ // RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s -!Pair = !pod.type<[@lhs: index, @rhs: index]> +!Pair = !pod.type<[@lhs: index, @rhs: !felt.type]> !PairRow = !array.type<2 x !Pair> !PairMatrix = !array.type<2,2 x !Pair> module attributes {llzk.lang} { @@ -14,15 +14,17 @@ module attributes {llzk.lang} { } } // CHECK-LABEL: module attributes {llzk.lang} { -// CHECK-NEXT: function.def @extract_then_insert(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>) -> (!array.type<2,2 x index>, !array.type<2,2 x index>) { +// CHECK-NEXT: function.def @extract_then_insert +// CHECK-SAME: (%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x !felt.type>) +// CHECK-SAME: -> (!array.type<2,2 x index>, !array.type<2,2 x !felt.type>) { // CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index // CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index // CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_0]]{{\[}}%[[VAL_2]]] : <2,2 x index> -// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_1]]{{\[}}%[[VAL_2]]] : <2,2 x index> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_1]]{{\[}}%[[VAL_2]]] : <2,2 x !felt.type> // CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = array.new : <2,2 x index> -// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.new : <2,2 x index> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.new : <2,2 x !felt.type> // CHECK-NEXT: array.insert %[[VAL_6]]{{\[}}%[[VAL_3]]] = %[[VAL_4]] : <2,2 x index>, <2 x index> -// CHECK-NEXT: array.insert %[[VAL_7]]{{\[}}%[[VAL_3]]] = %[[VAL_5]] : <2,2 x index>, <2 x index> -// CHECK-NEXT: function.return %[[VAL_6]], %[[VAL_7]] : !array.type<2,2 x index>, !array.type<2,2 x index> +// CHECK-NEXT: array.insert %[[VAL_7]]{{\[}}%[[VAL_3]]] = %[[VAL_5]] : <2,2 x !felt.type>, <2 x !felt.type> +// CHECK-NEXT: function.return %[[VAL_6]], %[[VAL_7]] : !array.type<2,2 x index>, !array.type<2,2 x !felt.type> // CHECK-NEXT: } // CHECK-NEXT: } From bc05e5fb36637945f9ffbb96f4cfcc38a2a343d7 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Mon, 22 Jun 2026 16:20:54 -0500 Subject: [PATCH 10/36] fix: avoid invalid member ref and fully handle array of pod --- .../POD/Transforms/PodToScalarPass.cpp | 580 +++++++++++++----- .../member_with_nested_pod_array.llzk | 63 ++ 2 files changed, 483 insertions(+), 160 deletions(-) create mode 100644 test/Transforms/PodToScalar/member_with_nested_pod_array.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 144cf40b2..e16c31e2c 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -27,20 +27,17 @@ /// 3. Run a dialect conversion that does the following: /// /// - Replace `MemberReadOp` and `MemberWriteOp` targeting the pod-typed struct members split in -/// step 1 so they instead perform reads and writes on the new scalar members. This -/// transformation is local to the current op. Therefore, when replacing a `MemberReadOp`, a -/// new pod is created locally and all uses of the `MemberReadOp` are replaced with that new -/// POD value, then each scalar member read is followed by a scalar write into the new pod. -/// Similarly, when replacing a `MemberWriteOp`, each element in the pod operand needs a scalar -/// read from the pod followed by a scalar write to the new member. Making only local changes -/// keeps this step simple and lets later steps optimize away the temporary POD storage. +/// step 1 so they instead perform reads and writes on the new scalar members. Reads and writes +/// are tracked through virtual POD placeholders so the conversion can keep propagating scalar +/// leaves instead of re-introducing aggregate POD storage. /// /// - Remove optional initialization from `NewPodOp` and instead insert a list of `WritePodOp` /// immediately following. /// /// - Split remaining direct POD values to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp`. -/// When a rewritten op still needs a POD-typed value locally, rebuild it with `pod.new` plus -/// `pod.write` so later cleanup can scalarize it away. +/// When a rewritten op still needs POD contents locally, keep them in the same virtual +/// placeholder form for as long as possible and only materialize concrete `pod.write` +/// operations as a fallback for unresolved uses. /// /// 4. Promote pod reads and writes out of `scf.if`, `scf.for`, and `scf.while` regions when the /// access can be modeled as an SSA value flowing through the region boundary. This puts the @@ -104,6 +101,8 @@ #include #include +#include + // Include the generated base pass class definitions. namespace llzk::pod { #define GEN_PASS_DEF_PODTOSCALARPASS @@ -188,15 +187,76 @@ template static bool containsSplittablePodType(ValueTypeRange ty return false; } +/// If the input ArrayType has a POD element type, return the input, else nullptr. +inline static ArrayType splittablePodArray(ArrayType at) { + return isa(at.getElementType()) ? at : nullptr; +} + +/// If the input Type is an ArrayType with a POD element type, return the input, else nullptr. +inline static ArrayType splittablePodArray(Type t) { + if (ArrayType at = dyn_cast(t)) { + return splittablePodArray(at); + } + return nullptr; +} + +/// Return the flattened leaf type addressed by `recordChain` within `type`. +static Type getFlattenedTypeAlongPath(Type type, ArrayRef recordChain) { + if (recordChain.empty()) { + return type; + } + + if (PodType podTy = dyn_cast(type)) { + Type nextType = podTy.getRecordMap().lookup(recordChain.front().getValue()); + assert(nextType && "record path must exist in the containing POD"); + return getFlattenedTypeAlongPath(nextType, recordChain.drop_front()); + } + + if (ArrayType arrTy = splittablePodArray(type)) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + Type nextType = elemPodTy.getRecordMap().lookup(recordChain.front().getValue()); + assert(nextType && "record path must exist in the POD array element type"); + return arrTy.cloneWith(getFlattenedTypeAlongPath(nextType, recordChain.drop_front())); + } + + llvm_unreachable("record path cannot continue through a non-POD leaf"); +} + +/// Visit each non-POD leaf record in `podTy`, providing its record-name chain and leaf type. +template +static void forEachPodLeaf(PodType podTy, SmallVectorImpl &recordChain, Fn &&callback) { + std::function walk = [&](Type type) { + if (PodType nestedPodTy = llvm::dyn_cast(type)) { + for (RecordAttr record : nestedPodTy.getRecords()) { + recordChain.push_back(record.getName()); + walk(record.getType()); + recordChain.pop_back(); + } + } else if (ArrayType arrTy = splittablePodArray(type)) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + for (RecordAttr record : elemPodTy.getRecords()) { + recordChain.push_back(record.getName()); + walk(arrTy.cloneWith(record.getType())); + recordChain.pop_back(); + } + } else { + callback(RecordChain(recordChain), type); + } + }; + + walk(podTy); +} + /// If the given Type is a PodType that can be split into scalars, append `collect` with all of /// the scalar types that result from splitting the PodType. Otherwise, just push the `Type`. size_t splitPodTypeTo(Type t, SmallVector &collect) { if (PodType pt = splittablePod(t)) { - auto records = pt.getRecords(); - for (RecordAttr record : records) { - collect.push_back(record.getType()); - } - return records.size(); + SmallVector recordChain; + size_t originalSize = collect.size(); + forEachPodLeaf(pt, recordChain, [&collect](RecordChain, Type leafType) { + collect.push_back(leafType); + }); + return collect.size() - originalSize; } else { collect.push_back(t); return 1; @@ -226,33 +286,6 @@ splitPodType(TypeCollection types, SmallVector *originalIdxToSize = null return collect; } -/// Visit each non-POD leaf record in `podTy`, providing its record-name chain and leaf type. -template -static void forEachPodLeaf(PodType podTy, SmallVectorImpl &recordChain, Fn &&callback) { - for (RecordAttr record : podTy.getRecords()) { - recordChain.push_back(record.getName()); - if (PodType nestedPodTy = dyn_cast(record.getType())) { - forEachPodLeaf(nestedPodTy, recordChain, callback); - } else { - callback(RecordChain(recordChain), record.getType()); - } - recordChain.pop_back(); - } -} - -/// If the input ArrayType has a POD element type, return the input, else nullptr. -inline static ArrayType splittablePodArray(ArrayType at) { - return isa(at.getElementType()) ? at : nullptr; -} - -/// If the input Type is an ArrayType with a POD element type, return the input, else nullptr. -inline static ArrayType splittablePodArray(Type t) { - if (ArrayType at = dyn_cast(t)) { - return splittablePodArray(at); - } - return nullptr; -} - /// Return `true` iff any type in the range is an array whose element type is a POD. inline static bool containsSplittablePodArrayType(ArrayRef types) { return llvm::any_of(types, [](Type t) { return splittablePodArray(t); }); @@ -329,16 +362,16 @@ static SmallVector getSplitPodArrayRecordNameSuffixes(Type type) { /// Create a `pod.read` for one record of `podRef`. inline static ReadPodOp -genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &rewriter) { +genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &bldr) { Type resultType = llvm::cast(podRef.getType()).getRecordMap().lookup(recordName.getValue()); - return rewriter.create(loc, resultType, podRef, recordName); + return bldr.create(loc, resultType, podRef, recordName); } /// Create a `pod.write` for one record of `podRef`. inline static WritePodOp -genWrite(Location loc, Value podRef, StringAttr recordName, Value value, OpBuilder &rewriter) { - return rewriter.create(loc, podRef, recordName, value); +genWrite(Location loc, Value podRef, StringAttr recordName, Value value, OpBuilder &bldr) { + return bldr.create(loc, podRef, recordName, value); } /// Return the single converted value from a 1:N adaptor range. @@ -357,51 +390,190 @@ static SmallVector flattenConvertedValues(RangeOfRanges ranges) { return values; } -/// Read a nested POD leaf by following each record name in `recordChain`. +/// Generate `arith.constant` indices for one static array element position. +static SmallVector genArrayIndexConstants(ArrayAttr index, Location loc, OpBuilder &bldr) { + SmallVector indices; + for (Attribute attr : index) { + assert(llvm::isa(attr) && "array index must be an integer attribute"); + indices.push_back(bldr.create(loc, llvm::cast(attr))); + } + return indices; +} + +/// Create an `array.read` for one concrete element or subarray. +inline static ReadArrayOp +genArrayRead(Location loc, Value arrayRef, ArrayAttr index, OpBuilder &bldr) { + Type t = arrayRef.getType(); + assert(llvm::isa(t) && "array.read must target an array type"); + return bldr.create( + loc, llvm::cast(t).getElementType(), arrayRef, + genArrayIndexConstants(index, loc, bldr) + ); +} + +/// Create an `array.write` for one concrete element or subarray. +inline static WriteArrayOp +genArrayWrite(Location loc, Value arrayRef, ArrayAttr index, Value value, OpBuilder &bldr) { + return bldr.create(loc, arrayRef, genArrayIndexConstants(index, loc, bldr), value); +} + +/// Read one flattened POD leaf, including leaves that live inside an array-of-POD record. static Value -genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &rewriter) { - Value value = podRef; - for (StringAttr attr : recordChain.nameList) { - value = genRead(loc, value, attr, rewriter); +genReadAlongPath(Location loc, Value value, ArrayRef recordChain, OpBuilder &bldr) { + if (recordChain.empty()) { + return value; } - return value; + + Type valueType = value.getType(); + if (llvm::isa(valueType)) { + Value nextValue = genRead(loc, value, recordChain.front(), bldr); + return genReadAlongPath(loc, nextValue, recordChain.drop_front(), bldr); + } + + if (ArrayType arrTy = splittablePodArray(valueType)) { + assert(arrTy.hasStaticShape() && "nested array-of-POD scalarization requires a static shape"); + auto splitArrTy = llvm::cast(getFlattenedTypeAlongPath(valueType, recordChain)); + auto subIndices = arrTy.getSubelementIndices(); + assert(subIndices && "static-shape arrays must provide subelement indices"); + + Value splitArray = bldr.create(loc, splitArrTy); + for (ArrayAttr index : *subIndices) { + Value element = genArrayRead(loc, value, index, bldr); + Value leafValue = genReadAlongPath(loc, element, recordChain, bldr); + genArrayWrite(loc, splitArray, index, leafValue, bldr); + } + return splitArray; + } + + llvm_unreachable("record path cannot continue through a non-POD leaf"); +} + +/// Read a flattened POD leaf by following each record name in `recordChain`. +inline static Value +genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &bldr) { + return genReadAlongPath(loc, podRef, ArrayRef(recordChain.nameList), bldr); } /// Reconstruct a POD record from the leaf values collected while splitting nested accesses. static Value rebuildFlattenedPodRecord( Location loc, Type recordType, SmallVectorImpl &recordChain, - const DenseMap &leafValues, ConversionPatternRewriter &rewriter + const DenseMap &leafValues, OpBuilder &bldr ) { if (PodType nestedPodTy = dyn_cast(recordType)) { - NewPodOp nestedPod = rewriter.create(loc, nestedPodTy); + NewPodOp nestedPod = bldr.create(loc, nestedPodTy); for (RecordAttr record : nestedPodTy.getRecords()) { recordChain.push_back(record.getName()); Value recordValue = - rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, rewriter); - genWrite(loc, nestedPod, record.getName(), recordValue, rewriter); + rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, bldr); + genWrite(loc, nestedPod, record.getName(), recordValue, bldr); recordChain.pop_back(); } return nestedPod; } + if (ArrayType arrTy = splittablePodArray(recordType)) { + assert(arrTy.hasStaticShape() && "nested array-of-POD scalarization requires a static shape"); + auto elemPodTy = llvm::cast(arrTy.getElementType()); + auto subIndices = arrTy.getSubelementIndices(); + assert(subIndices && "static-shape arrays must provide subelement indices"); + + Value rebuiltArray = bldr.create(loc, arrTy); + for (ArrayAttr index : *subIndices) { + DenseMap elementLeafValues; + SmallVector elementRecordChain; + forEachPodLeaf(elemPodTy, elementRecordChain, [&](RecordChain id, Type) { + SmallVector fullChain(recordChain.begin(), recordChain.end()); + llvm::append_range(fullChain, id.nameList); + auto it = leafValues.find(RecordChain(fullChain)); + assert(it != leafValues.end() && "missing flattened POD array leaf value"); + elementLeafValues[id] = genArrayRead(loc, it->second, index, bldr); + }); + + NewPodOp elementPod = bldr.create(loc, elemPodTy); + SmallVector nestedChain; + for (RecordAttr record : elemPodTy.getRecords()) { + nestedChain.push_back(record.getName()); + Value recordValue = + rebuildFlattenedPodRecord(loc, record.getType(), nestedChain, elementLeafValues, bldr); + genWrite(loc, elementPod, record.getName(), recordValue, bldr); + nestedChain.pop_back(); + } + genArrayWrite(loc, rebuiltArray, index, elementPod, bldr); + } + return rebuiltArray; + } + auto it = leafValues.find(RecordChain(recordChain)); assert(it != leafValues.end() && "missing flattened POD leaf value"); return it->second; } +/// Populate a POD value from its flattened leaf values. +static void populateFlattenedPodValue( + Location loc, Value podValue, PodType podTy, const DenseMap &leafValues, + OpBuilder &bldr +) { + SmallVector recordChain; + for (RecordAttr record : podTy.getRecords()) { + recordChain.push_back(record.getName()); + Value recordValue = + rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, bldr); + genWrite(loc, podValue, record.getName(), recordValue, bldr); + recordChain.pop_back(); + } +} + +using VirtualPodLeafMap = DenseMap; +using VirtualPodValueMap = DenseMap; + +/// Return the flattened leaf values for `podValue` when it is tracked as a virtual POD. +static const VirtualPodLeafMap * +lookupVirtualPodLeafMap(Value podValue, const VirtualPodValueMap &virtualPods) { + auto it = virtualPods.find(podValue); + return it != virtualPods.end() ? &it->second : nullptr; +} + +/// Collect flattened POD leaf values in canonical traversal order. +static SmallVector +orderedVirtualPodLeafValues(PodType podTy, const VirtualPodLeafMap &leafValues) { + SmallVector orderedValues; + SmallVector recordChain; + forEachPodLeaf(podTy, recordChain, [&leafValues, &orderedValues](RecordChain id, Type) { + auto it = leafValues.find(id); + assert(it != leafValues.end() && "missing virtual POD leaf value"); + orderedValues.push_back(it->second); + }); + return orderedValues; +} + +/// Materialize the tracked contents of a virtual POD into concrete `pod.write` operations. +inline static void +materializeVirtualPod(NewPodOp pod, const VirtualPodLeafMap &leafValues, OpBuilder &bldr) { + populateFlattenedPodValue(pod.getLoc(), pod, pod.getType(), leafValues, bldr); +} + +/// Return `true` iff a read from a virtual POD can be resolved without materializing it. +static bool canResolveVirtualPodRead(ReadPodOp op, const VirtualPodValueMap &virtualPods) { + if (!lookupVirtualPodLeafMap(op.getPodRef(), virtualPods)) { + return false; + } + Type recType = llvm::cast(op.getPodRefType()).getRecordMap().lookup(op.getRecordName()); + return llvm::isa(recType) || !splittablePodArray(recType); +} + /// Return the suffixes to append to a function arg/result name when splitting the given type. static SmallVector getSplitRecordNameSuffixes(Type type) { SmallVector suffixes; if (PodType pt = splittablePod(type)) { - suffixes.reserve(pt.getRecords().size()); - for (RecordAttr record : pt.getRecords()) { - StringRef name = record.getName().getValue(); - std::string result; - result.reserve(name.size() + 1); - result.push_back('.'); - result.append(name.data(), name.size()); - suffixes.push_back(result); - } + SmallVector recordChain; + forEachPodLeaf(pt, recordChain, [&suffixes](RecordChain id, Type) { + std::string suffix; + llvm::raw_string_ostream os(suffix); + for (StringAttr recordName : id.nameList) { + os << '.' << recordName.getValue(); + } + suffixes.push_back(std::move(suffix)); + }); } return suffixes; } @@ -410,12 +582,19 @@ static SmallVector getSplitRecordNameSuffixes(Type type) { // add the original operand to the list. static void processInputOperand( Location loc, Value operand, SmallVector &newOperands, - ConversionPatternRewriter &rewriter + ConversionPatternRewriter &rewriter, const VirtualPodValueMap *virtualPods = nullptr ) { if (PodType pt = splittablePod(operand.getType())) { - for (RecordAttr record : pt.getRecords()) { - newOperands.push_back(genRead(loc, operand, record.getName(), rewriter)); + if (virtualPods) { + if (const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(operand, *virtualPods)) { + llvm::append_range(newOperands, orderedVirtualPodLeafValues(pt, *leafValues)); + return; + } } + SmallVector recordChain; + forEachPodLeaf(pt, recordChain, [&](RecordChain id, Type) { + newOperands.push_back(genReadAlongPath(loc, operand, id, rewriter)); + }); } else { newOperands.push_back(operand); } @@ -425,11 +604,11 @@ static void processInputOperand( /// and update the op to use the new operands. static void processInputOperands( ValueRange operands, MutableOperandRange outputOpRef, Operation *op, - ConversionPatternRewriter &rewriter + ConversionPatternRewriter &rewriter, const VirtualPodValueMap *virtualPods = nullptr ) { SmallVector newOperands; for (Value v : operands) { - processInputOperand(op->getLoc(), v, newOperands, rewriter); + processInputOperand(op->getLoc(), v, newOperands, rewriter, virtualPods); } rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() { outputOpRef.assign(ValueRange(newOperands)); @@ -1167,8 +1346,11 @@ class SplitInitFromNewPodOp : public OpConversionPattern { /// rest of the function can continue to use POD values until later cleanup passes scalarize those /// local temporaries away. class SplitPodInFuncDefOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + public: - using OpConversionPattern::OpConversionPattern; + SplitPodInFuncDefOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} inline static bool legal(FuncDefOp op) { return !containsSplittablePodType(op.getArgumentTypes()) && @@ -1183,6 +1365,7 @@ class SplitPodInFuncDefOp : public OpConversionPattern { SmallVector originalInputIdxToSize, originalResultIdxToSize; SplitFunctionNameInfo inputNameInfo; SplitFunctionNameInfo resultNameInfo; + VirtualPodValueMap &virtualPods; protected: SmallVector convertInputs(ArrayRef origTypes) override { @@ -1218,18 +1401,19 @@ class SplitPodInFuncDefOp : public OpConversionPattern { Value oldV = entryBlock.getArgument(i); if (PodType pt = splittablePod(oldV.getType())) { Location loc = oldV.getLoc(); - // Generate `NewPodOp` and replace uses of the argument with it. auto newPod = rewriter.create(loc, pt); rewriter.replaceAllUsesWith(oldV, newPod); // Remove the argument from the block entryBlock.eraseArgument(i); - // For all indices in the PodType (i.e., the element count), generate a new - // block argument and a write of that argument to the new pod. - for (RecordAttr record : pt.getRecords()) { - BlockArgument newArg = entryBlock.insertArgument(i, record.getType(), loc); - genWrite(loc, newPod, record.getName(), newArg, rewriter); + + DenseMap leafValues; + SmallVector recordChain; + forEachPodLeaf(pt, recordChain, [&](RecordChain id, Type leafType) { + BlockArgument newArg = entryBlock.insertArgument(i, leafType, loc); + leafValues[id] = newArg; ++i; - } + }); + virtualPods[newPod] = std::move(leafValues); } else { ++i; } @@ -1237,7 +1421,7 @@ class SplitPodInFuncDefOp : public OpConversionPattern { } public: - Impl(FuncDefOp op) { + Impl(FuncDefOp op, VirtualPodValueMap &virtualPodMap) : virtualPods(virtualPodMap) { inputNameInfo = collectSplitFunctionNameInfo(op.getArgumentTypes(), [&op](unsigned i) { return op.getArgNameAttr(i); }, getSplitRecordNameSuffixes); @@ -1248,7 +1432,7 @@ class SplitPodInFuncDefOp : public OpConversionPattern { ); } }; - Impl(op).convert(op, rewriter); + Impl(op, virtualPods).convert(op, rewriter); } }; @@ -1258,8 +1442,11 @@ class SplitPodInFuncDefOp : public OpConversionPattern { /// are returned as one SSA value per record, using local `pod.read` operations to extract the /// scalar pieces immediately before the return. class SplitPodInReturnOp : public OpConversionPattern { + const VirtualPodValueMap &virtualPods; + public: - using OpConversionPattern::OpConversionPattern; + SplitPodInReturnOp(MLIRContext *ctx, const VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} inline static bool legal(ReturnOp op) { return !containsSplittablePodType(op.getOperands().getTypes()); @@ -1268,13 +1455,16 @@ class SplitPodInReturnOp : public OpConversionPattern { LogicalResult match(ReturnOp op) const override { return failure(legal(op)); } void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter); + processInputOperands( + adaptor.getOperands(), op.getOperandsMutable(), op, rewriter, &virtualPods + ); } }; /// Rebuild a call with split scalar results, then reconstruct POD-typed results locally. static CallOp newCallOpWithSplitResults( - CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter + CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, + VirtualPodValueMap &virtualPods ) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(oldCall); @@ -1289,15 +1479,15 @@ static CallOp newCallOpWithSplitResults( for (Value oldVal : oldResults) { if (PodType pt = splittablePod(oldVal.getType())) { Location loc = oldVal.getLoc(); - // Generate `NewPodOp` and replace uses of the result with it. - auto newPod = rewriter.create(loc, pt); + DenseMap leafValues; + SmallVector recordChain; + forEachPodLeaf(pt, recordChain, [&leafValues, &newResults](RecordChain id, Type) { + leafValues[id] = *newResults; + ++newResults; + }); + NewPodOp newPod = rewriter.create(loc, pt); + virtualPods[newPod] = std::move(leafValues); rewriter.replaceAllUsesWith(oldVal, newPod); - - // For each record in the PodType, write the next result from the new CallOp to the new pod. - for (RecordAttr record : pt.getRecords()) { - genWrite(loc, newPod, record.getName(), *newResults, rewriter); - newResults++; - } } else { rewriter.replaceAllUsesWith(oldVal, *newResults); newResults++; @@ -1316,8 +1506,11 @@ static CallOp newCallOpWithSplitResults( /// the original POD-typed uses in the caller until later optimization passes remove the temporary /// POD allocations. class SplitPodInCallOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + public: - using OpConversionPattern::OpConversionPattern; + SplitPodInCallOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} inline static bool legal(CallOp op) { return !containsSplittablePodType(op.getArgOperands().getTypes()) && @@ -1328,84 +1521,137 @@ class SplitPodInCallOp : public OpConversionPattern { void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Create new CallOp with split results first so, then process its inputs to split types - CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter); + CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter, virtualPods); processInputOperands( - newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter + newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter, &virtualPods ); } }; -/// State used while rebuilding a POD from flattened struct-member leaves. -struct RebuildPodReadState { - NewPodOp pod; - DenseMap leafValues; -}; - /// Rewrite a write to a pod-typed struct member into writes to the corresponding scalar leaves. -class SplitPodInMemberWriteOp : public SplitAggregateInMemberRefOp< - SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain> { +class SplitPodInMemberWriteOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + const VirtualPodValueMap &virtualPods; + public: - using SplitAggregateInMemberRefOp< - SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain>::SplitAggregateInMemberRefOp; + SplitPodInMemberWriteOp( + MLIRContext *ctx, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap, + const VirtualPodValueMap &virtualPodMap + ) + : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap), + virtualPods(virtualPodMap) {} static bool legal(MemberWriteOp op) { return !containsSplittablePodType(op.getVal().getType()); } - static void *genHeader(MemberWriteOp, ConversionPatternRewriter &) { return nullptr; } + LogicalResult match(MemberWriteOp op) const override { return failure(legal(op)); } - static void forId( - Location loc, void *&, RecordChain id, MemberInfo newMember, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter - ) { - Value scalarRead = genReadAlongPath(loc, adaptor.getVal(), id, rewriter); - rewriter.create( - loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead - ); + void + rewrite(MemberWriteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); + + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + const VirtualPodLeafMap *virtualLeafValues = + lookupVirtualPodLeafMap(adaptor.getVal(), virtualPods); + + for (auto [id, newMember] : idToMember) { + Value scalarValue = virtualLeafValues + ? virtualLeafValues->at(id) + : genReadAlongPath(op.getLoc(), adaptor.getVal(), id, rewriter); + rewriter.create( + op.getLoc(), adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarValue + ); + } + rewriter.eraseOp(op); } }; /// Rewrite a read from a pod-typed struct member into reads from the corresponding scalar leaves. -class SplitPodInMemberReadOp - : public SplitAggregateInMemberRefOp< - SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState, RecordChain> { +class SplitPodInMemberReadOp : public OpConversionPattern { + SymbolTableCollection &tables; + const MemberReplacementMap &repMapRef; + VirtualPodValueMap &virtualPods; + public: - using SplitAggregateInMemberRefOp< - SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState, - RecordChain>::SplitAggregateInMemberRefOp; + SplitPodInMemberReadOp( + MLIRContext *ctx, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap, + VirtualPodValueMap &virtualPodMap + ) + : OpConversionPattern(ctx), tables(symTables), repMapRef(memberRepMap), + virtualPods(virtualPodMap) {} static bool legal(MemberReadOp op) { return !containsSplittablePodType(op.getResult().getType()); } - static RebuildPodReadState genHeader(MemberReadOp op, ConversionPatternRewriter &rewriter) { - RebuildPodReadState state; - state.pod = rewriter.create(op.getLoc(), llvm::cast(op.getType())); - rewriter.replaceAllUsesWith(op, state.pod); - return state; - } + LogicalResult match(MemberReadOp op) const override { return failure(legal(op)); } - static void forId( - Location loc, RebuildPodReadState &state, RecordChain id, MemberInfo newMember, - OpAdaptor adaptor, ConversionPatternRewriter &rewriter - ) { - Value scalarRead = rewriter.create( - loc, newMember.second, adaptor.getComponent(), newMember.first - ); - state.leafValues[id] = scalarRead; - } + void + rewrite(MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + StructType tgtStructTy = llvm::cast(op.getOperation()).getStructType(); + auto tgtStructDef = tgtStructTy.getDefinition(tables, op); + assert(succeeded(tgtStructDef)); - static void finalize( - MemberReadOp op, RebuildPodReadState &state, OpAdaptor, ConversionPatternRewriter &rewriter - ) { - auto podTy = llvm::cast(op.getType()); - SmallVector recordChain; - for (RecordAttr record : podTy.getRecords()) { - recordChain.push_back(record.getName()); - Value recordValue = rebuildFlattenedPodRecord( - op.getLoc(), record.getType(), recordChain, state.leafValues, rewriter + const LocalMemberReplacementMap &idToMember = + repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); + + VirtualPodLeafMap leafValues; + for (auto [id, newMember] : idToMember) { + leafValues[id] = rewriter.create( + op.getLoc(), newMember.second, adaptor.getComponent(), newMember.first ); - genWrite(op.getLoc(), state.pod, record.getName(), recordValue, rewriter); - recordChain.pop_back(); } + + NewPodOp pod = rewriter.create(op.getLoc(), llvm::cast(op.getType())); + virtualPods[pod] = std::move(leafValues); + rewriter.replaceOp(op, pod); + } +}; + +/// Resolve reads from a virtual POD placeholder without materializing the whole aggregate. +class ResolveVirtualPodReadOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + +public: + ResolveVirtualPodReadOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + + LogicalResult matchAndRewrite( + ReadPodOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(adaptor.getPodRef(), virtualPods); + if (!leafValues) { + return failure(); + } + + SmallVector prefix {op.getRecordNameAttr()}; + Type recordType = + llvm::cast(op.getPodRefType()).getRecordMap().lookup(op.getRecordName()); + assert(recordType && "record must exist in POD type"); + + if (PodType nestedPodTy = llvm::dyn_cast(recordType)) { + VirtualPodLeafMap nestedLeafValues; + SmallVector nestedRecordChain; + forEachPodLeaf(nestedPodTy, nestedRecordChain, [&](RecordChain id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + nestedLeafValues[id] = leafValues->at(RecordChain(fullChain)); + }); + NewPodOp pod = rewriter.create(op.getLoc(), nestedPodTy); + virtualPods[pod] = std::move(nestedLeafValues); + rewriter.replaceOp(op, pod); + return success(); + } + + if (splittablePodArray(recordType)) { + return failure(); + } + + rewriter.replaceOp(op, leafValues->at(RecordChain(prefix))); + return success(); } }; @@ -1414,23 +1660,15 @@ class SplitPodInMemberReadOp static LogicalResult step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { MLIRContext *ctx = modOp.getContext(); + VirtualPodValueMap virtualPods; RewritePatternSet patterns(ctx); - patterns.add< - // clang-format off - SplitInitFromNewPodOp, - SplitPodInFuncDefOp, - SplitPodInReturnOp, - SplitPodInCallOp - // clang-format on - >(ctx); - - patterns.add< - // clang-format off - SplitPodInMemberWriteOp, - SplitPodInMemberReadOp - // clang-format on - >(ctx, symTables, memberRepMap); + patterns.add(ctx); + patterns.add(ctx, virtualPods); + patterns.add( + ctx, symTables, memberRepMap, virtualPods + ); + patterns.add(ctx, virtualPods); ConversionTarget target(*ctx); baseTargetSetup(target); @@ -1440,9 +1678,26 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp(SplitPodInCallOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberWriteOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberReadOp::legal); + target.addDynamicallyLegalOp([&virtualPods](ReadPodOp op) { + return !canResolveVirtualPodRead(op, virtualPods); + }); LLVM_DEBUG(llvm::dbgs() << "Begin step 3: update/split other pod ops\n";); - return applyFullConversion(modOp, target, std::move(patterns)); + if (failed(applyFullConversion(modOp, target, std::move(patterns)))) { + return failure(); + } + + OpBuilder builder(ctx); + for (auto &[podValue, leafValues] : virtualPods) { + if (podValue.use_empty()) { + continue; + } + if (auto newPod = llvm::dyn_cast(podValue.getDefiningOp())) { + builder.setInsertionPointAfter(newPod); + materializeVirtualPod(newPod, leafValues, builder); + } + } + return success(); } /// Return whether the given read/write access targets the same POD record. @@ -2383,6 +2638,11 @@ class PassImpl : public llzk::pod::impl::PodToScalarPassBase { // Cleanup allocations made dead by memory promotion and other dead SSA values. OpPassManager cleanupPM(ModuleOp::getOperationName()); + cleanupPM.addPass(createRemoveUnusedDiscardableAllocationsPass( + RemoveUnusedDiscardableAllocationsPassOptions { + .allocatorOpName = CreateArrayOp::getOperationName().str() + } + )); cleanupPM.addPass(createRemoveUnusedDiscardableAllocationsPass( RemoveUnusedDiscardableAllocationsPassOptions { .allocatorOpName = NewPodOp::getOperationName().str() diff --git a/test/Transforms/PodToScalar/member_with_nested_pod_array.llzk b/test/Transforms/PodToScalar/member_with_nested_pod_array.llzk new file mode 100644 index 000000000..5c675fe47 --- /dev/null +++ b/test/Transforms/PodToScalar/member_with_nested_pod_array.llzk @@ -0,0 +1,63 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + + function.def @compute(%p: !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]>) + -> !struct.type<@S> { + %self = struct.new : !struct.type<@S> + struct.writem %self[@m] = %p : !struct.type<@S>, !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + %loaded = struct.readm %self[@m] : !struct.type<@S>, !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + %b = pod.read %loaded[@b] : !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]>, index + function.return %self : !struct.type<@S> + } + + function.def @constrain( + %self: !struct.type<@S>, %p: !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + ) { + %loaded = struct.readm %self[@m] : !struct.type<@S>, !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]> + %b = pod.read %loaded[@b] : !pod.type<[@a: !array.type<2 x !pod.type<[@x: index, @y: !felt.type]>>, @b: index]>, index + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-DAG: struct.member @m_a_x : !array.type<2 x index> +// CHECK-DAG: struct.member @m_a_y : !array.type<2 x !felt.type> +// CHECK-DAG: struct.member @m_b : index +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, %[[VAL_2:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-DAG: struct.writem %[[VAL_3]][@m_a_y] = %[[VAL_1]] : <@S>, !array.type<2 x !felt.type> +// CHECK-DAG: struct.writem %[[VAL_3]][@m_b] = %[[VAL_2]] : <@S>, index +// CHECK-DAG: struct.writem %[[VAL_3]][@m_a_x] = %[[VAL_0]] : <@S>, !array.type<2 x index> +// CHECK-DAG: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@m_a_y] : <@S>, !array.type<2 x !felt.type> +// CHECK-DAG: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@m_b] : <@S>, index +// CHECK-DAG: %[[VAL_6:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@m_a_x] : <@S>, !array.type<2 x index> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_6]]{{\[}}%[[VAL_7]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_4]]{{\[}}%[[VAL_9]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_6]]{{\[}}%[[VAL_11]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_4]]{{\[}}%[[VAL_13]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: function.return %[[VAL_3]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_15:[0-9a-zA-Z_\.]+]]: !struct.type<@S>, %[[VAL_16:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_17:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, %[[VAL_18:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-DAG: %[[VAL_19:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_15]][@m_a_y] : <@S>, !array.type<2 x !felt.type> +// CHECK-DAG: %[[VAL_20:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_15]][@m_b] : <@S>, index +// CHECK-DAG: %[[VAL_21:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_15]][@m_a_x] : <@S>, !array.type<2 x index> +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_21]]{{\[}}%[[VAL_22]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_24]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: %[[VAL_26:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_27:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_21]]{{\[}}%[[VAL_26]]] : <2 x index>, index +// CHECK-NEXT: %[[VAL_28:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_29:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_19]]{{\[}}%[[VAL_28]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } From 3166d4520e6ad5d66138f3967ae538f625ae6ec9 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Tue, 23 Jun 2026 10:41:29 -0500 Subject: [PATCH 11/36] fix: Avoid assuming one backing array for `array.len` --- .../POD/Transforms/PodToScalarPass.cpp | 41 +++++++++- test/Transforms/PodToScalar/array_length.llzk | 82 +++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 test/Transforms/PodToScalar/array_length.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index e16c31e2c..afa5ec206 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -380,6 +380,40 @@ inline static Value getSingleConvertedValue(ValueRange values) { return values.front(); } +/// Materialize a scalar array value that preserves the shape of `originalArrTy`. +/// +/// This is used as a shape-only carrier for `array.len` when an array-of-POD splits to +/// zero parallel leaf arrays (for example, `!array.type<... x !pod.type<[]>>`). +static Value materializeArrayLengthCarrier( + Value originalArrRef, ArrayType originalArrTy, Location loc, ConversionPatternRewriter &rewriter +) { + ArrayType carrierTy = originalArrTy.cloneWith(IndexType::get(rewriter.getContext())); + + if (auto create = originalArrRef.getDefiningOp()) { + if (create.getMapOperands().empty()) { + return rewriter.create(loc, carrierTy); + } + + SmallVector mapOperands; + mapOperands.reserve(create.getMapOperands().size()); + for (OperandRange mapOperandGroup : create.getMapOperands()) { + mapOperands.push_back(mapOperandGroup); + } + return rewriter.create( + loc, carrierTy, mapOperands, create.getNumDimsPerMapAttr() + ); + } + + bool hasAffineDims = llvm::any_of(originalArrTy.getDimensionSizes(), [](Attribute dimSize) { + return llvm::isa(dimSize); + }); + if (!hasAffineDims) { + return rewriter.create(loc, carrierTy); + } + + return rewriter.create(loc, carrierTy); +} + /// Flatten a range of converted value ranges into a single list of values. template static SmallVector flattenConvertedValues(RangeOfRanges ranges) { @@ -1118,8 +1152,13 @@ class SplitPodArrayLengthOp : public OpConversionPattern { if (legal(op)) { return failure(); } + Value arrRef = adaptor.getArrRef().empty() + ? materializeArrayLengthCarrier( + op.getArrRef(), op.getArrRefType(), op.getLoc(), rewriter + ) + : adaptor.getArrRef().front(); rewriter.replaceOpWithNewOp( - op, getSingleConvertedValue(adaptor.getArrRef()), getSingleConvertedValue(adaptor.getDim()) + op, arrRef, getSingleConvertedValue(adaptor.getDim()) ); return success(); } diff --git a/test/Transforms/PodToScalar/array_length.llzk b/test/Transforms/PodToScalar/array_length.llzk new file mode 100644 index 000000000..8e44ec94a --- /dev/null +++ b/test/Transforms/PodToScalar/array_length.llzk @@ -0,0 +1,82 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @len_multi_leaf(%arr: !array.type<2 x !Pair>, %dim: index) -> index { + %len = array.len %arr, %dim : !array.type<2 x !Pair> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_multi_leaf(%[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[DIM]] : <2 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + function.def @len_empty_leaf_static_create(%dim: index) -> index { + %arr = array.new : !array.type<4 x !pod.type<[]>> + %len = array.len %arr, %dim : !array.type<4 x !pod.type<[]>> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_static_create(%[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <4 x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <4 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#map = affine_map<()[s0] -> (s0)> +module attributes {llzk.lang} { + function.def @len_empty_leaf_array(%n: index, %dim: index) -> index { + %arr = array.new{()[%n]} : !array.type<#map x !pod.type<[]>> + %len = array.len %arr, %dim : !array.type<#map x !pod.type<[]>> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_array(%[[N:[0-9a-zA-Z_\.]+]]: index, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new{()[%[[N]]]} : <#map x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#map x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + function.def @len_empty_leaf_static_arg(%arr: !array.type<4 x !pod.type<[]>>, %dim: index) -> index { + %len = array.len %arr, %dim : !array.type<4 x !pod.type<[]>> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_static_arg(%[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <4 x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <4 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#map_arg = affine_map<()[s0] -> (s0)> +module attributes {llzk.lang} { + function.def @len_empty_leaf_affine_arg( + %arr: !array.type<#map_arg x !pod.type<[]>>, %dim: index + ) -> index { + %len = array.len %arr, %dim : !array.type<#map_arg x !pod.type<[]>> + function.return %len : index + } +} +// CHECK: #map = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_affine_arg(%[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#map x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#map x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } From 2d0db6920143bfe69a24dabf24105498d0968f72 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Tue, 23 Jun 2026 10:52:31 -0500 Subject: [PATCH 12/36] fix: Populate result split sizes before cloning attrs --- lib/Dialect/POD/Transforms/PodToScalarPass.cpp | 6 ++++++ .../function_result_attrs_array_of_pod.llzk | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 test/Transforms/PodToScalar/function_result_attrs_array_of_pod.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index afa5ec206..7ff78b835 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -1023,6 +1023,12 @@ class SplitPodArrayInFuncDefOp : public OpConversionPattern { SmallVector originalInputIdxToSize, originalResultIdxToSize; SmallVector newInputs = splitPodArrayType(oldTy.getInputs(), &originalInputIdxToSize); + SmallVector newResultsWithSizeInfo = + splitPodArrayType(oldTy.getResults(), &originalResultIdxToSize); + assert( + newResultsWithSizeInfo == newResults && + "expected array-of-pod type conversion to match function result attr replication" + ); SplitFunctionNameInfo inputNameInfo = collectSplitFunctionNameInfo(op.getArgumentTypes(), [&](unsigned i) { return op.getArgNameAttr(i); diff --git a/test/Transforms/PodToScalar/function_result_attrs_array_of_pod.llzk b/test/Transforms/PodToScalar/function_result_attrs_array_of_pod.llzk new file mode 100644 index 000000000..b3e69ddef --- /dev/null +++ b/test/Transforms/PodToScalar/function_result_attrs_array_of_pod.llzk @@ -0,0 +1,16 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@lhs: index, @rhs: !felt.type]> +!PodArray = !array.type<2 x !Pair> +module attributes {llzk.lang} { + function.def @named_array_result(%arg: !PodArray) -> (!PodArray {function.res_name = "out"}) { + function.return %arg : !PodArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @named_array_result +// CHECK-SAME: (%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>) +// CHECK-SAME: -> (!array.type<2 x index> {function.res_name = "out.lhs"}, !array.type<2 x !felt.type> {function.res_name = "out.rhs"}) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2 x index>, !array.type<2 x !felt.type> +// CHECK-NEXT: } +// CHECK-NEXT: } From e5161712b980db2a22b3f02f51a929b994e6466d Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Tue, 23 Jun 2026 12:28:03 -0500 Subject: [PATCH 13/36] fix: split equality constraints over POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 154 +++++++++++++++++- .../PodToScalar/constrain_array_of_pod.llzk | 91 +++++++++++ 2 files changed, 242 insertions(+), 3 deletions(-) create mode 100644 test/Transforms/PodToScalar/constrain_array_of_pod.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 7ff78b835..3d47c006c 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -21,8 +21,8 @@ /// rewriting steps. /// /// 2. Run a dialect conversion that splits arrays whose element type is a POD into parallel arrays -/// in `llzk.nondet`, `array.*`, `MemberReadOp`, `MemberWriteOp`, `FuncDefOp`, `CallOp`, and -/// `ReturnOp`. +/// in `llzk.nondet`, `array.*`, `constrain.eq`, `constrain.in`, `struct.readm`, `struct.writem`, +/// `function.def`, `function.call`, and `function.return`. /// /// 3. Run a dialect conversion that does the following: /// @@ -70,6 +70,7 @@ #include "llzk/Dialect/Bool/IR/Dialect.h" #include "llzk/Dialect/Cast/IR/Dialect.h" #include "llzk/Dialect/Constrain/IR/Dialect.h" +#include "llzk/Dialect/Constrain/IR/Ops.h" #include "llzk/Dialect/Felt/IR/Dialect.h" #include "llzk/Dialect/Function/IR/Dialect.h" #include "llzk/Dialect/Function/IR/Ops.h" @@ -1145,6 +1146,147 @@ class SplitPodArrayInCallOp : public OpConversionPattern { } }; +/// Rewrite `constrain.eq` over arrays-of-POD into one equality per parallel leaf array. +class SplitPodArrayInEmitEqualityOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(constrain::EmitEqualityOp op) { + return !containsSplittablePodArrayType(op->getOperandTypes()); + } + + LogicalResult matchAndRewrite( + constrain::EmitEqualityOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + if (adaptor.getLhs().size() != adaptor.getRhs().size()) { + return rewriter.notifyMatchFailure( + op, "expected array-of-pod equality operands to expand to the same number of leaves" + ); + } + + for (auto [lhs, rhs] : llvm::zip_equal(adaptor.getLhs(), adaptor.getRhs())) { + rewriter.create(op.getLoc(), lhs, rhs); + } + rewriter.eraseOp(op); + return success(); + } +}; + +/// Rewrite `constrain.in` over arrays-of-POD into a shared-slice witness plus leaf equalities. +/// +/// After step 2 converts an array-of-POD into parallel leaf arrays, `constrain.in` can no longer be +/// left in place because it has no built-in 1:N operand rewrite. This pattern preserves the +/// original containment semantics by: +/// +/// 1. Expanding both operands into matching POD leaves. +/// 2. Computing how many leading lhs dimensions must be selected to match the rhs rank. +/// 3. Creating one nondeterministic index per selected dimension and constraining each index to be +/// in bounds using `array.len` and `constrain.eq` on the comparison results. +/// 4. Using that same index tuple for every leaf, reading a scalar leaf with `array.read` or +/// extracting an array leaf with `array.extract`. +/// 5. Emitting one `constrain.eq` per selected lhs leaf and rhs leaf, then erasing the original +/// `constrain.in`. +/// +/// Reusing the same nondeterministic indices across all leaves is essential: it guarantees that all +/// field equalities refer to the same POD element or subarray, rather than allowing different +/// leaves to match at different positions. +class SplitPodArrayInEmitContainmentOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(constrain::EmitContainmentOp op) { + return !containsSplittablePodArrayType(op->getOperandTypes()); + } + + /// Return the split scalar or leaf-array values representing one containment operand. + static SmallVector collectContainmentLeaves( + Location loc, Value originalOperand, ValueRange convertedValues, + ConversionPatternRewriter &rewriter + ) { + if (splittablePod(originalOperand.getType())) { + SmallVector podLeaves; + processInputOperand(loc, getSingleConvertedValue(convertedValues), podLeaves, rewriter); + return podLeaves; + } + + return SmallVector(convertedValues.begin(), convertedValues.end()); + } + + LogicalResult matchAndRewrite( + constrain::EmitContainmentOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + Location loc = op.getLoc(); + ArrayType lhsTy = op.getLhs().getType(); + Type rhsTy = op.getRhs().getType(); + + size_t lhsRank = lhsTy.getDimensionSizes().size(); + size_t rhsRank = 0; + if (auto rhsArrTy = llvm::dyn_cast(rhsTy)) { + rhsRank = rhsArrTy.getDimensionSizes().size(); + } + assert(lhsRank >= rhsRank && "constrain.in verifier should reject higher-rank rhs arrays"); + size_t selectedDims = lhsRank - rhsRank; + + SmallVector lhsLeaves(adaptor.getLhs().begin(), adaptor.getLhs().end()); + SmallVector rhsLeaves = + collectContainmentLeaves(loc, op.getRhs(), adaptor.getRhs(), rewriter); + if (lhsLeaves.size() != rhsLeaves.size()) { + return rewriter.notifyMatchFailure( + op, "expected array-of-pod containment operands to expand to the same number of leaves" + ); + } + + Value shapeCarrier = adaptor.getLhs().empty() + ? materializeArrayLengthCarrier(op.getLhs(), lhsTy, loc, rewriter) + : adaptor.getLhs().front(); + Value zero = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value trueVal = rewriter.create( + loc, IntegerAttr::get(IntegerType::get(rewriter.getContext(), 1), 1) + ); + + SmallVector selectedIndices; + selectedIndices.reserve(selectedDims); + for (size_t dim = 0; dim < selectedDims; ++dim) { + Value idx = rewriter.create(loc, IndexType::get(rewriter.getContext())); + Value dimVal = rewriter.create(loc, rewriter.getIndexAttr(dim)); + Value dimLen = rewriter.create(loc, shapeCarrier, dimVal); + + Value nonNegative = rewriter.create(loc, arith::CmpIPredicate::sge, idx, zero); + rewriter.create(loc, nonNegative, trueVal); + + Value inRange = rewriter.create(loc, arith::CmpIPredicate::slt, idx, dimLen); + rewriter.create(loc, inRange, trueVal); + + selectedIndices.push_back(idx); + } + + for (auto [lhsLeaf, rhsLeaf] : llvm::zip_equal(lhsLeaves, rhsLeaves)) { + Value selectedLhs = lhsLeaf; + if (auto rhsLeafArrTy = llvm::dyn_cast(rhsLeaf.getType())) { + if (!selectedIndices.empty()) { + selectedLhs = + rewriter.create(loc, rhsLeafArrTy, lhsLeaf, selectedIndices); + } + } else { + selectedLhs = + rewriter.create(loc, rhsLeaf.getType(), lhsLeaf, selectedIndices); + } + rewriter.create(loc, selectedLhs, rhsLeaf); + } + + rewriter.eraseOp(op); + return success(); + } +}; + /// Replace `array.length` on an array-of-POD with the equivalent length of any split leaf array. class SplitPodArrayLengthOp : public OpConversionPattern { public: @@ -1331,7 +1473,9 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM SplitPodArrayNonDetOp, SplitPodArrayCreateArrayOp, SplitPodArrayReadArrayOp, SplitPodArrayWriteArrayOp, SplitPodArrayExtractArrayOp, SplitPodArrayInsertArrayOp, SplitPodArrayInFuncDefOp, SplitPodArrayInReturnOp, SplitPodArrayInCallOp, - SplitPodArrayLengthOp>(typeConverter, ctx); + SplitPodArrayInEmitEqualityOp, SplitPodArrayInEmitContainmentOp, SplitPodArrayLengthOp>( + typeConverter, ctx + ); patterns.add( typeConverter, ctx, symTables, memberRepMap ); @@ -1347,6 +1491,10 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp(SplitPodArrayInFuncDefOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInReturnOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInCallOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInEmitEqualityOp::legal); + target.addDynamicallyLegalOp( + SplitPodArrayInEmitContainmentOp::legal + ); target.addDynamicallyLegalOp(SplitPodArrayLengthOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInMemberWriteOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInMemberReadOp::legal); diff --git a/test/Transforms/PodToScalar/constrain_array_of_pod.llzk b/test/Transforms/PodToScalar/constrain_array_of_pod.llzk new file mode 100644 index 000000000..44557bf67 --- /dev/null +++ b/test/Transforms/PodToScalar/constrain_array_of_pod.llzk @@ -0,0 +1,91 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @eq_array_pod( + %lhs: !array.type<2 x !Pair>, %rhs: !array.type<2 x !Pair> + ) attributes {function.allow_constraint} { + constrain.eq %lhs, %rhs : !array.type<2 x !Pair> + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @eq_array_pod( +// CHECK-SAME: %[[LHS_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[LHS_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, +// CHECK-SAME: %[[RHS_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[RHS_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type> +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: constrain.eq %[[LHS_X]], %[[RHS_X]] : !array.type<2 x index>, !array.type<2 x index> +// CHECK-NEXT: constrain.eq %[[LHS_Y]], %[[RHS_Y]] : !array.type<2 x !felt.type>, !array.type<2 x !felt.type> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @contains_pod( + %arr: !array.type<2 x !Pair>, %elem: !Pair + ) attributes {function.allow_constraint} { + constrain.in %arr, %elem : !array.type<2 x !Pair>, !Pair + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @contains_pod( +// CHECK-SAME: %[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>, +// CHECK-SAME: %[[ELEM_X:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[ELEM_Y:[0-9a-zA-Z_\.]+]]: !felt.type +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[TRUE:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: %[[IDX:[0-9a-zA-Z_\.]+]] = llzk.nondet : index +// CHECK-NEXT: %[[DIM0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[DIM0]] : <2 x index> +// CHECK-NEXT: %[[GE0:[0-9a-zA-Z_\.]+]] = arith.cmpi sge, %[[IDX]], %[[C0]] : index +// CHECK-NEXT: constrain.eq %[[GE0]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[LT_LEN:[0-9a-zA-Z_\.]+]] = arith.cmpi slt, %[[IDX]], %[[LEN]] : index +// CHECK-NEXT: constrain.eq %[[LT_LEN]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[SEL_X:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_X]]{{\[}}%[[IDX]]] : <2 x index>, index +// CHECK-NEXT: constrain.eq %[[SEL_X]], %[[ELEM_X]] : index, index +// CHECK-NEXT: %[[SEL_Y:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_Y]]{{\[}}%[[IDX]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: constrain.eq %[[SEL_Y]], %[[ELEM_Y]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @contains_subarray( + %arr: !array.type<2,2 x !Pair>, %sub: !array.type<2 x !Pair> + ) attributes {function.allow_constraint} { + constrain.in %arr, %sub : !array.type<2,2 x !Pair>, !array.type<2 x !Pair> + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @contains_subarray( +// CHECK-SAME: %[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x index>, +// CHECK-SAME: %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2,2 x !felt.type>, +// CHECK-SAME: %[[SUB_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[SUB_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type> +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[TRUE:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: %[[IDX:[0-9a-zA-Z_\.]+]] = llzk.nondet : index +// CHECK-NEXT: %[[DIM0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[DIM0]] : <2,2 x index> +// CHECK-NEXT: %[[GE0:[0-9a-zA-Z_\.]+]] = arith.cmpi sge, %[[IDX]], %[[C0]] : index +// CHECK-NEXT: constrain.eq %[[GE0]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[LT_LEN:[0-9a-zA-Z_\.]+]] = arith.cmpi slt, %[[IDX]], %[[LEN]] : index +// CHECK-NEXT: constrain.eq %[[LT_LEN]], %[[TRUE]] : i1, i1 +// CHECK-NEXT: %[[SEL_X:[0-9a-zA-Z_\.]+]] = array.extract %[[ARR_X]]{{\[}}%[[IDX]]] : <2,2 x index> +// CHECK-NEXT: constrain.eq %[[SEL_X]], %[[SUB_X]] : !array.type<2 x index>, !array.type<2 x index> +// CHECK-NEXT: %[[SEL_Y:[0-9a-zA-Z_\.]+]] = array.extract %[[ARR_Y]]{{\[}}%[[IDX]]] : <2,2 x !felt.type> +// CHECK-NEXT: constrain.eq %[[SEL_Y]], %[[SUB_Y]] : !array.type<2 x !felt.type>, !array.type<2 x !felt.type> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } From c5d0d6291cb6f3ae28b05082b487c6fd454987d6 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Tue, 23 Jun 2026 14:27:37 -0500 Subject: [PATCH 14/36] fix: Avoid creating arrays whose elements are arrays --- include/llzk/Util/TypeHelper.h | 8 + .../POD/Transforms/PodToScalarPass.cpp | 186 ++++++++++-------- .../Polymorphic/Transforms/FlatteningPass.cpp | 14 +- lib/Util/TypeHelper.cpp | 9 + .../PodToScalar/array_leaf_in_pod_array.llzk | 44 +++++ 5 files changed, 166 insertions(+), 95 deletions(-) create mode 100644 test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk diff --git a/include/llzk/Util/TypeHelper.h b/include/llzk/Util/TypeHelper.h index eebd173ab..d76e58259 100644 --- a/include/llzk/Util/TypeHelper.h +++ b/include/llzk/Util/TypeHelper.h @@ -171,6 +171,14 @@ namespace llzk { bool isDynamic(mlir::IntegerAttr intAttr); +/// Flatten any array-valued element type into the dimensions of `outerArrTy`. +/// +/// This is used when an LLZK array logically resolves to a higher-rank array even though array +/// element types cannot themselves be arrays. The returned type keeps `outerArrTy`'s leading +/// dimensions, appends any nested dimensions from `elementType`, and uses the innermost non-array +/// element type as the final element type. +array::ArrayType flattenArrayElementType(array::ArrayType outerArrTy, mlir::Type elementType); + /// Compute the cardinality (i.e. number of scalar constraints) for an EmitEqualityOp type since the /// op can be used to constrain two same-size arrays. uint64_t computeEmitEqCardinality(mlir::Type type); diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 3d47c006c..20a93e250 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -89,6 +89,7 @@ #include "llzk/Transforms/LLZKTransformationPasses.h" #include "llzk/Transforms/SpecializedMemoryPasses.h" #include "llzk/Util/Concepts.h" +#include "llzk/Util/TypeHelper.h" #include "llzk/Util/Walk.h" #include @@ -217,7 +218,9 @@ static Type getFlattenedTypeAlongPath(Type type, ArrayRef recordChai auto elemPodTy = llvm::cast(arrTy.getElementType()); Type nextType = elemPodTy.getRecordMap().lookup(recordChain.front().getValue()); assert(nextType && "record path must exist in the POD array element type"); - return arrTy.cloneWith(getFlattenedTypeAlongPath(nextType, recordChain.drop_front())); + return flattenArrayElementType( + arrTy, getFlattenedTypeAlongPath(nextType, recordChain.drop_front()) + ); } llvm_unreachable("record path cannot continue through a non-POD leaf"); @@ -237,7 +240,7 @@ static void forEachPodLeaf(PodType podTy, SmallVectorImpl &recordCha auto elemPodTy = llvm::cast(arrTy.getElementType()); for (RecordAttr record : elemPodTy.getRecords()) { recordChain.push_back(record.getName()); - walk(arrTy.cloneWith(record.getType())); + walk(flattenArrayElementType(arrTy, record.getType())); recordChain.pop_back(); } } else { @@ -306,7 +309,7 @@ static size_t splitPodArrayTypeTo( SmallVector recordChain; size_t originalSize = collect.size(); forEachPodLeaf(podTy, recordChain, [&](RecordChain id, Type leafType) { - collect.push_back(at.cloneWith(leafType)); + collect.push_back(flattenArrayElementType(at, leafType)); if (splitIds) { splitIds->push_back(std::move(id)); } @@ -363,7 +366,7 @@ static SmallVector getSplitPodArrayRecordNameSuffixes(Type type) { /// Create a `pod.read` for one record of `podRef`. inline static ReadPodOp -genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &bldr) { +genRead(OpBuilder &bldr, Location loc, Value podRef, StringAttr recordName) { Type resultType = llvm::cast(podRef.getType()).getRecordMap().lookup(recordName.getValue()); return bldr.create(loc, resultType, podRef, recordName); @@ -371,7 +374,7 @@ genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &bldr) { /// Create a `pod.write` for one record of `podRef`. inline static WritePodOp -genWrite(Location loc, Value podRef, StringAttr recordName, Value value, OpBuilder &bldr) { +genWrite(OpBuilder &bldr, Location loc, Value podRef, StringAttr recordName, Value value) { return bldr.create(loc, podRef, recordName, value); } @@ -426,7 +429,7 @@ static SmallVector flattenConvertedValues(RangeOfRanges ranges) { } /// Generate `arith.constant` indices for one static array element position. -static SmallVector genArrayIndexConstants(ArrayAttr index, Location loc, OpBuilder &bldr) { +static SmallVector genArrayIndexConstants(OpBuilder &bldr, Location loc, ArrayAttr index) { SmallVector indices; for (Attribute attr : index) { assert(llvm::isa(attr) && "array index must be an integer attribute"); @@ -435,34 +438,64 @@ static SmallVector genArrayIndexConstants(ArrayAttr index, Location loc, return indices; } -/// Create an `array.read` for one concrete element or subarray. -inline static ReadArrayOp -genArrayRead(Location loc, Value arrayRef, ArrayAttr index, OpBuilder &bldr) { +/// Return the type produced by selecting `numIndices` leading dimensions from `arrTy`. +static Type getArraySelectionType(ArrayType arrTy, size_t numIndices) { + assert(numIndices <= arrTy.getDimensionSizes().size() && "cannot select past the array rank"); + if (numIndices == arrTy.getDimensionSizes().size()) { + return arrTy.getElementType(); + } + return ArrayType::get(arrTy.getElementType(), arrTy.getDimensionSizes().drop_front(numIndices)); +} + +/// Create an `array.read` or `array.extract` for one concrete element or subarray. +static Value genArrayRead(OpBuilder &bldr, Location loc, Value arrayRef, ArrayRef indices) { Type t = arrayRef.getType(); - assert(llvm::isa(t) && "array.read must target an array type"); - return bldr.create( - loc, llvm::cast(t).getElementType(), arrayRef, - genArrayIndexConstants(index, loc, bldr) + assert(llvm::isa(t) && "array access must target an array type"); + ArrayType arrTy = llvm::cast(t); + if (indices.size() == arrTy.getDimensionSizes().size()) { + return bldr.create(loc, arrTy.getElementType(), arrayRef, indices); + } + return bldr.create( + loc, llvm::cast(getArraySelectionType(arrTy, indices.size())), arrayRef, indices ); } -/// Create an `array.write` for one concrete element or subarray. -inline static WriteArrayOp -genArrayWrite(Location loc, Value arrayRef, ArrayAttr index, Value value, OpBuilder &bldr) { - return bldr.create(loc, arrayRef, genArrayIndexConstants(index, loc, bldr), value); +inline static Value genArrayRead(OpBuilder &bldr, Location loc, Value arrayRef, ArrayAttr index) { + SmallVector indices = genArrayIndexConstants(bldr, loc, index); + return genArrayRead(bldr, loc, arrayRef, indices); +} + +/// Create an `array.write` or `array.insert` for one concrete element or subarray. +static void +genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayRef indices, Value value) { + Type t = arrayRef.getType(); + assert(llvm::isa(t) && "array access must target an array type"); + ArrayType arrTy = llvm::cast(t); + if (indices.size() == arrTy.getDimensionSizes().size()) { + bldr.create(loc, arrayRef, indices, value); + return; + } + assert(llvm::isa(value.getType()) && "subarray insertion requires an array value"); + bldr.create(loc, arrayRef, indices, value); +} + +inline static void +genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayAttr index, Value value) { + SmallVector indices = genArrayIndexConstants(bldr, loc, index); + genArrayWrite(bldr, loc, arrayRef, indices, value); } /// Read one flattened POD leaf, including leaves that live inside an array-of-POD record. static Value -genReadAlongPath(Location loc, Value value, ArrayRef recordChain, OpBuilder &bldr) { +genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef recordChain) { if (recordChain.empty()) { return value; } Type valueType = value.getType(); if (llvm::isa(valueType)) { - Value nextValue = genRead(loc, value, recordChain.front(), bldr); - return genReadAlongPath(loc, nextValue, recordChain.drop_front(), bldr); + Value nextValue = genRead(bldr, loc, value, recordChain.front()); + return genReadAlongPath(bldr, loc, nextValue, recordChain.drop_front()); } if (ArrayType arrTy = splittablePodArray(valueType)) { @@ -473,9 +506,9 @@ genReadAlongPath(Location loc, Value value, ArrayRef recordChain, Op Value splitArray = bldr.create(loc, splitArrTy); for (ArrayAttr index : *subIndices) { - Value element = genArrayRead(loc, value, index, bldr); - Value leafValue = genReadAlongPath(loc, element, recordChain, bldr); - genArrayWrite(loc, splitArray, index, leafValue, bldr); + Value element = genArrayRead(bldr, loc, value, index); + Value leafValue = genReadAlongPath(bldr, loc, element, recordChain); + genArrayWrite(bldr, loc, splitArray, index, leafValue); } return splitArray; } @@ -485,22 +518,22 @@ genReadAlongPath(Location loc, Value value, ArrayRef recordChain, Op /// Read a flattened POD leaf by following each record name in `recordChain`. inline static Value -genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &bldr) { - return genReadAlongPath(loc, podRef, ArrayRef(recordChain.nameList), bldr); +genReadAlongPath(OpBuilder &bldr, Location loc, Value podRef, RecordChain recordChain) { + return genReadAlongPath(bldr, loc, podRef, ArrayRef(recordChain.nameList)); } /// Reconstruct a POD record from the leaf values collected while splitting nested accesses. static Value rebuildFlattenedPodRecord( - Location loc, Type recordType, SmallVectorImpl &recordChain, - const DenseMap &leafValues, OpBuilder &bldr + OpBuilder &bldr, Location loc, Type recordType, SmallVectorImpl &recordChain, + const DenseMap &leafValues ) { if (PodType nestedPodTy = dyn_cast(recordType)) { NewPodOp nestedPod = bldr.create(loc, nestedPodTy); for (RecordAttr record : nestedPodTy.getRecords()) { recordChain.push_back(record.getName()); Value recordValue = - rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, bldr); - genWrite(loc, nestedPod, record.getName(), recordValue, bldr); + rebuildFlattenedPodRecord(bldr, loc, record.getType(), recordChain, leafValues); + genWrite(bldr, loc, nestedPod, record.getName(), recordValue); recordChain.pop_back(); } return nestedPod; @@ -521,7 +554,7 @@ static Value rebuildFlattenedPodRecord( llvm::append_range(fullChain, id.nameList); auto it = leafValues.find(RecordChain(fullChain)); assert(it != leafValues.end() && "missing flattened POD array leaf value"); - elementLeafValues[id] = genArrayRead(loc, it->second, index, bldr); + elementLeafValues[id] = genArrayRead(bldr, loc, it->second, index); }); NewPodOp elementPod = bldr.create(loc, elemPodTy); @@ -529,11 +562,11 @@ static Value rebuildFlattenedPodRecord( for (RecordAttr record : elemPodTy.getRecords()) { nestedChain.push_back(record.getName()); Value recordValue = - rebuildFlattenedPodRecord(loc, record.getType(), nestedChain, elementLeafValues, bldr); - genWrite(loc, elementPod, record.getName(), recordValue, bldr); + rebuildFlattenedPodRecord(bldr, loc, record.getType(), nestedChain, elementLeafValues); + genWrite(bldr, loc, elementPod, record.getName(), recordValue); nestedChain.pop_back(); } - genArrayWrite(loc, rebuiltArray, index, elementPod, bldr); + genArrayWrite(bldr, loc, rebuiltArray, index, elementPod); } return rebuiltArray; } @@ -543,21 +576,6 @@ static Value rebuildFlattenedPodRecord( return it->second; } -/// Populate a POD value from its flattened leaf values. -static void populateFlattenedPodValue( - Location loc, Value podValue, PodType podTy, const DenseMap &leafValues, - OpBuilder &bldr -) { - SmallVector recordChain; - for (RecordAttr record : podTy.getRecords()) { - recordChain.push_back(record.getName()); - Value recordValue = - rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, bldr); - genWrite(loc, podValue, record.getName(), recordValue, bldr); - recordChain.pop_back(); - } -} - using VirtualPodLeafMap = DenseMap; using VirtualPodValueMap = DenseMap; @@ -583,8 +601,16 @@ orderedVirtualPodLeafValues(PodType podTy, const VirtualPodLeafMap &leafValues) /// Materialize the tracked contents of a virtual POD into concrete `pod.write` operations. inline static void -materializeVirtualPod(NewPodOp pod, const VirtualPodLeafMap &leafValues, OpBuilder &bldr) { - populateFlattenedPodValue(pod.getLoc(), pod, pod.getType(), leafValues, bldr); +materializeVirtualPod(OpBuilder &bldr, NewPodOp pod, const VirtualPodLeafMap &leafValues) { + Location loc = pod.getLoc(); + SmallVector recordChain; + for (RecordAttr record : pod.getType().getRecords()) { + recordChain.push_back(record.getName()); + Value recordValue = + rebuildFlattenedPodRecord(bldr, loc, record.getType(), recordChain, leafValues); + genWrite(bldr, loc, pod, record.getName(), recordValue); + recordChain.pop_back(); + } } /// Return `true` iff a read from a virtual POD can be resolved without materializing it. @@ -628,7 +654,7 @@ static void processInputOperand( } SmallVector recordChain; forEachPodLeaf(pt, recordChain, [&](RecordChain id, Type) { - newOperands.push_back(genReadAlongPath(loc, operand, id, rewriter)); + newOperands.push_back(genReadAlongPath(rewriter, loc, operand, id)); }); } else { newOperands.push_back(operand); @@ -883,7 +909,7 @@ class SplitPodArrayCreateArrayOp : public OpConversionPattern { splitElements.reserve(adaptor.getElements().size()); for (ValueRange elementRange : adaptor.getElements()) { Value element = getSingleConvertedValue(elementRange); - splitElements.push_back(genReadAlongPath(op.getLoc(), element, id, rewriter)); + splitElements.push_back(genReadAlongPath(rewriter, op.getLoc(), element, id)); } replacements.push_back(rewriter.create( op.getLoc(), llvm::cast(splitType), splitElements @@ -912,7 +938,7 @@ class SplitPodArrayCreateArrayOp : public OpConversionPattern { } }; -/// Split `array.read` from an array-of-POD into scalar leaf reads plus local POD reconstruction. +/// Split `array.read` from an array-of-POD into leaf reads plus local POD reconstruction. class SplitPodArrayReadArrayOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -934,22 +960,18 @@ class SplitPodArrayReadArrayOp : public OpConversionPattern { SmallVector indices = flattenConvertedValues(adaptor.getIndices()); NewPodOp pod = rewriter.create(op.getLoc(), podTy); DenseMap leafValues; - for (auto [id, splitArrRange, splitType] : - llvm::zip_equal(splitIds, adaptor.getArrRef(), splitTypes)) { - auto splitArrTy = llvm::cast(splitType); - Value scalarRead = rewriter.create( - op.getLoc(), splitArrTy.getElementType(), getSingleConvertedValue(splitArrRange), indices - ); - leafValues[id] = scalarRead; + for (auto [id, splitArrRange] : llvm::zip_equal(splitIds, adaptor.getArrRef())) { + leafValues[id] = + genArrayRead(rewriter, op.getLoc(), getSingleConvertedValue(splitArrRange), indices); } SmallVector recordChain; for (RecordAttr record : podTy.getRecords()) { recordChain.push_back(record.getName()); Value recordValue = rebuildFlattenedPodRecord( - op.getLoc(), record.getType(), recordChain, leafValues, rewriter + rewriter, op.getLoc(), record.getType(), recordChain, leafValues ); - genWrite(op.getLoc(), pod, record.getName(), recordValue, rewriter); + genWrite(rewriter, op.getLoc(), pod, record.getName(), recordValue); recordChain.pop_back(); } rewriter.replaceOp(op, pod); @@ -979,9 +1001,9 @@ class SplitPodArrayWriteArrayOp : public OpConversionPattern { Value podValue = getSingleConvertedValue(adaptor.getRvalue()); for (auto [id, splitArrRange, splitType] : llvm::zip_equal(splitIds, adaptor.getArrRef(), splitTypes)) { - Value leafValue = genReadAlongPath(op.getLoc(), podValue, id, rewriter); - rewriter.create( - op.getLoc(), getSingleConvertedValue(splitArrRange), indices, leafValue + Value leafValue = genReadAlongPath(rewriter, op.getLoc(), podValue, id); + genArrayWrite( + rewriter, op.getLoc(), getSingleConvertedValue(splitArrRange), indices, leafValue ); } rewriter.eraseOp(op); @@ -1753,7 +1775,7 @@ class SplitPodInMemberWriteOp : public OpConversionPattern { for (auto [id, newMember] : idToMember) { Value scalarValue = virtualLeafValues ? virtualLeafValues->at(id) - : genReadAlongPath(op.getLoc(), adaptor.getVal(), id, rewriter); + : genReadAlongPath(rewriter, op.getLoc(), adaptor.getVal(), id); rewriter.create( op.getLoc(), adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarValue ); @@ -1887,7 +1909,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM } if (auto newPod = llvm::dyn_cast(podValue.getDefiningOp())) { builder.setInsertionPointAfter(newPod); - materializeVirtualPod(newPod, leafValues, builder); + materializeVirtualPod(builder, newPod, leafValues); } } return success(); @@ -2055,7 +2077,7 @@ class ReplaceIfReadPattern final : public OpRewritePattern { rewriter.setInsertionPoint(ifOp); rewriter.replaceOp( - readOp, genRead(readOp.getLoc(), readOp.getPodRef(), readOp.getRecordNameAttr(), rewriter) + readOp, genRead(rewriter, readOp.getLoc(), readOp.getPodRef(), readOp.getRecordNameAttr()) .getResult() ); return success(); @@ -2260,8 +2282,8 @@ moveBranchWithoutLiftedWrites(Block *srcBlock, Block &destBlock, ArrayRef slots, - bool isThenBlock, OpBuilder &builder + OpBuilder &bldr, Location loc, Block &block, ValueRange priorYieldValues, + ArrayRef slots, bool isThenBlock ) { SmallVector yieldValues = llvm::to_vector(priorYieldValues); llvm::append_range(yieldValues, llvm::map_range(slots, [isThenBlock](const IfWriteSlot &slot) { @@ -2269,8 +2291,8 @@ static void appendYield( return writeOp ? writeOp.getValue() : slot.incomingValue; })); - builder.setInsertionPointToEnd(&block); - builder.create(loc, yieldValues); + bldr.setInsertionPointToEnd(&block); + bldr.create(loc, yieldValues); } /// One POD record whose value is carried across an SCF loop boundary as an SSA scalar. @@ -2431,7 +2453,7 @@ class LiftPodWritesFromIfBlocksPattern final : public OpRewritePattern resultTypes = llvm::to_vector(ifOp.getResultTypes()); @@ -2458,15 +2480,15 @@ class LiftPodWritesFromIfBlocksPattern final : public OpRewritePattern newInitArgs = llvm::to_vector(forOp.getInitArgs()); rewriter.setInsertionPoint(forOp); for (const LoopPodSlot &slot : slots) { - newInitArgs.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult()); + newInitArgs.push_back(genRead(rewriter, loc, slot.podRef, slot.recordName).getResult()); } auto newFor = rewriter.create( @@ -2551,7 +2573,7 @@ class LiftPodAccessesFromForLoopPattern final : public OpRewritePattern newResultTypes = llvm::to_vector(whileOp.getResultTypes()); rewriter.setInsertionPoint(whileOp); for (const LoopPodSlot &slot : slots) { - newInits.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult()); + newInits.push_back(genRead(rewriter, loc, slot.podRef, slot.recordName).getResult()); newResultTypes.push_back(slot.type); } @@ -2694,8 +2716,8 @@ class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern mergedDims(inputTy.getDimensionSizes()); - while (ArrayType nestedArrTy = llvm::dyn_cast(convertedElemTy)) { - llvm::append_range(mergedDims, nestedArrTy.getDimensionSizes()); - convertedElemTy = nestedArrTy.getElementType(); - } - return ArrayType::get(convertedElemTy, mergedDims); -} - /// TypeConverter for function instantiation that replaces TypeVarType and symbolic /// ArrayType/StructType parameters with their concrete values determined by unification. class FuncInstTypeConverter : public TypeConverter { @@ -1020,7 +1008,7 @@ class FuncInstTypeConverter : public TypeConverter { if (!changed && newElemTy == inputTy.getElementType()) { return inputTy; } - return flattenInstantiatedArrayType( + return flattenArrayElementType( inputTy.cloneWith(inputTy.getElementType(), updated), newElemTy ); }); diff --git a/lib/Util/TypeHelper.cpp b/lib/Util/TypeHelper.cpp index 2b7c5fa5d..f511c86e0 100644 --- a/lib/Util/TypeHelper.cpp +++ b/lib/Util/TypeHelper.cpp @@ -564,6 +564,15 @@ bool hasAffineMapAttr(Type type) { bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); } +ArrayType flattenArrayElementType(ArrayType outerArrTy, Type elementType) { + SmallVector mergedDims(outerArrTy.getDimensionSizes()); + while (ArrayType nestedArrTy = llvm::dyn_cast(elementType)) { + llvm::append_range(mergedDims, nestedArrTy.getDimensionSizes()); + elementType = nestedArrTy.getElementType(); + } + return ArrayType::get(elementType, mergedDims); +} + uint64_t computeEmitEqCardinality(Type type) { struct Impl : LLZKTypeSwitch { uint64_t caseBool(IntegerType) { return 1; } diff --git a/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk b/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk new file mode 100644 index 000000000..cb7ffaae1 --- /dev/null +++ b/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk @@ -0,0 +1,44 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Elem = !pod.type<[@vals: !array.type<3 x index>, @tag: index]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @pass_through(%arg: !ElemArray) -> !ElemArray { + function.return %arg : !ElemArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @pass_through( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,3 x index>, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x index> +// CHECK-SAME: ) -> (!array.type<2,3 x index>, !array.type<2 x index>) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2,3 x index>, !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Elem = !pod.type<[@vals: !array.type<3 x index>, @tag: index]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @update_first(%arr: !ElemArray, %newVals: !array.type<3 x index>) -> !ElemArray { + %c0 = arith.constant 0 : index + %elem = array.read %arr[%c0] : !ElemArray, !Elem + pod.write %elem[@vals] = %newVals : !Elem, !array.type<3 x index> + array.write %arr[%c0] = %elem : !ElemArray, !Elem + function.return %arr : !ElemArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @update_first( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,3 x index>, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]]: !array.type<3 x index> +// CHECK-SAME: ) -> (!array.type<2,3 x index>, !array.type<2 x index>) { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_0]]{{\[}}%[[VAL_3]]] : <2,3 x index> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_1]]{{\[}}%[[VAL_3]]] : <2 x index>, index +// CHECK-NEXT: array.insert %[[VAL_0]]{{\[}}%[[VAL_3]]] = %[[VAL_2]] : <2,3 x index>, <3 x index> +// CHECK-NEXT: array.write %[[VAL_1]]{{\[}}%[[VAL_3]]] = %[[VAL_5]] : <2 x index>, index +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2,3 x index>, !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } From 34bd6e0edf83fb0df17540a4dfbb5a04539a78f6 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Tue, 23 Jun 2026 14:39:30 -0500 Subject: [PATCH 15/36] fix dangling ref --- lib/Dialect/POD/Transforms/PodToScalarPass.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 20a93e250..5b5b202aa 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -603,8 +603,9 @@ orderedVirtualPodLeafValues(PodType podTy, const VirtualPodLeafMap &leafValues) inline static void materializeVirtualPod(OpBuilder &bldr, NewPodOp pod, const VirtualPodLeafMap &leafValues) { Location loc = pod.getLoc(); + PodType podTy = pod.getType(); SmallVector recordChain; - for (RecordAttr record : pod.getType().getRecords()) { + for (RecordAttr record : podTy.getRecords()) { recordChain.push_back(record.getName()); Value recordValue = rebuildFlattenedPodRecord(bldr, loc, record.getType(), recordChain, leafValues); From 45632675f96058e0b5676a8d5a563ca1a6b438a9 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Tue, 23 Jun 2026 16:05:23 -0500 Subject: [PATCH 16/36] additional regression test --- ...tion_signatures_with_nested_pod_array.llzk | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 test/Transforms/PodToScalar/function_signatures_with_nested_pod_array.llzk diff --git a/test/Transforms/PodToScalar/function_signatures_with_nested_pod_array.llzk b/test/Transforms/PodToScalar/function_signatures_with_nested_pod_array.llzk new file mode 100644 index 000000000..8e140cff2 --- /dev/null +++ b/test/Transforms/PodToScalar/function_signatures_with_nested_pod_array.llzk @@ -0,0 +1,28 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Inner = !pod.type<[@x: index]> +!Outer = !pod.type<[@a: !array.type<2 x !Inner>, @b: index]> +module attributes {llzk.lang} { + function.def @id( + %arg: !Outer {function.arg_name = "arg"} + ) -> (!Outer {function.res_name = "out"}) { + function.return %arg : !Outer + } + + function.def @main(%arg: !Outer) -> !Outer { + %res = function.call @id(%arg) : (!Outer) -> !Outer + function.return %res : !Outer + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @id( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2 x index> {function.arg_name = "arg.a.x"}, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: index {function.arg_name = "arg.b"} +// CHECK-SAME: ) -> (!array.type<2 x index> {function.res_name = "out.a.x"}, index {function.res_name = "out.b"}) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<2 x index>, index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @main(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: index) -> (!array.type<2 x index>, index) { +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]]:2 = function.call @id(%[[VAL_2]], %[[VAL_3]]) : (!array.type<2 x index>, index) -> (!array.type<2 x index>, index) +// CHECK-NEXT: function.return %[[VAL_4]]#0, %[[VAL_4]]#1 : !array.type<2 x index>, index +// CHECK-NEXT: } +// CHECK-NEXT: } From fae79594acf07f1c793d9b5282e8c6d82e68ad31 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Tue, 23 Jun 2026 16:51:30 -0500 Subject: [PATCH 17/36] clang-tidy cleanup --- include/llzk/Util/Walk.h | 2 +- .../POD/Transforms/PodToScalarPass.cpp | 32 ++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/include/llzk/Util/Walk.h b/include/llzk/Util/Walk.h index 7f02f341c..c1a0636ba 100644 --- a/include/llzk/Util/Walk.h +++ b/include/llzk/Util/Walk.h @@ -30,7 +30,7 @@ inline static bool walkContainsMatch(R &root, llvm::function_ref inline static bool walkContains(R &root) { - return root.walk([](MatchType t) { return mlir::WalkResult::interrupt(); }).wasInterrupted(); + return root.walk([](MatchType) { return mlir::WalkResult::interrupt(); }).wasInterrupted(); } /// Collect all walked operations of type `MatchType` rooted at `root` into a diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 5b5b202aa..55bd0903c 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -257,7 +257,7 @@ size_t splitPodTypeTo(Type t, SmallVector &collect) { if (PodType pt = splittablePod(t)) { SmallVector recordChain; size_t originalSize = collect.size(); - forEachPodLeaf(pt, recordChain, [&collect](RecordChain, Type leafType) { + forEachPodLeaf(pt, recordChain, [&collect](const RecordChain &, Type leafType) { collect.push_back(leafType); }); return collect.size() - originalSize; @@ -518,7 +518,7 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef orderedVirtualPodLeafValues(PodType podTy, const VirtualPodLeafMap &leafValues) { SmallVector orderedValues; SmallVector recordChain; - forEachPodLeaf(podTy, recordChain, [&leafValues, &orderedValues](RecordChain id, Type) { + forEachPodLeaf(podTy, recordChain, [&leafValues, &orderedValues](const RecordChain &id, Type) { auto it = leafValues.find(id); assert(it != leafValues.end() && "missing virtual POD leaf value"); orderedValues.push_back(it->second); @@ -628,7 +628,7 @@ static SmallVector getSplitRecordNameSuffixes(Type type) { SmallVector suffixes; if (PodType pt = splittablePod(type)) { SmallVector recordChain; - forEachPodLeaf(pt, recordChain, [&suffixes](RecordChain id, Type) { + forEachPodLeaf(pt, recordChain, [&suffixes](const RecordChain &id, Type) { std::string suffix; llvm::raw_string_ostream os(suffix); for (StringAttr recordName : id.nameList) { @@ -654,7 +654,7 @@ static void processInputOperand( } } SmallVector recordChain; - forEachPodLeaf(pt, recordChain, [&](RecordChain id, Type) { + forEachPodLeaf(pt, recordChain, [&](const RecordChain &id, Type) { newOperands.push_back(genReadAlongPath(rewriter, loc, operand, id)); }); } else { @@ -741,7 +741,7 @@ static void flattenPodMemberIntoLeaves( LocalMemberReplacementMap &localRepMapRef, SymbolTable &structSymbolTable, ConversionPatternRewriter &rewriter ) { - forEachPodLeaf(podTy, recordChain, [&](RecordChain id, Type ty) { + forEachPodLeaf(podTy, recordChain, [&](const RecordChain &id, Type ty) { StringAttr name = getFlattenedMemberName( originalMember.getContext(), originalMember.getSymNameAttr(), id.nameList ); @@ -1026,7 +1026,7 @@ class SplitPodArrayInFuncDefOp : public OpConversionPattern { LogicalResult matchAndRewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { - auto *tyConv = getTypeConverter(); + const auto *tyConv = getTypeConverter(); assert(tyConv && "expected pod-array type converter"); FunctionType oldTy = op.getFunctionType(); @@ -1120,7 +1120,7 @@ class SplitPodArrayInCallOp : public OpConversionPattern { LogicalResult matchAndRewrite( CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter ) const override { - auto *tyConv = getTypeConverter(); + const auto *tyConv = getTypeConverter(); assert(tyConv && "expected pod-array type converter"); SmallVector newResultTypes; @@ -1279,7 +1279,9 @@ class SplitPodArrayInEmitContainmentOp : public OpConversionPattern(loc, IndexType::get(rewriter.getContext())); - Value dimVal = rewriter.create(loc, rewriter.getIndexAttr(dim)); + Value dimVal = rewriter.create( + loc, rewriter.getIndexAttr(llzk::checkedCast(dim)) + ); Value dimLen = rewriter.create(loc, shapeCarrier, dimVal); Value nonNegative = rewriter.create(loc, arith::CmpIPredicate::sge, idx, zero); @@ -1624,7 +1626,7 @@ class SplitPodInFuncDefOp : public OpConversionPattern { DenseMap leafValues; SmallVector recordChain; - forEachPodLeaf(pt, recordChain, [&](RecordChain id, Type leafType) { + forEachPodLeaf(pt, recordChain, [&](const RecordChain &id, Type leafType) { BlockArgument newArg = entryBlock.insertArgument(i, leafType, loc); leafValues[id] = newArg; ++i; @@ -1697,7 +1699,7 @@ static CallOp newCallOpWithSplitResults( Location loc = oldVal.getLoc(); DenseMap leafValues; SmallVector recordChain; - forEachPodLeaf(pt, recordChain, [&leafValues, &newResults](RecordChain id, Type) { + forEachPodLeaf(pt, recordChain, [&leafValues, &newResults](const RecordChain &id, Type) { leafValues[id] = *newResults; ++newResults; }); @@ -1773,7 +1775,7 @@ class SplitPodInMemberWriteOp : public OpConversionPattern { const VirtualPodLeafMap *virtualLeafValues = lookupVirtualPodLeafMap(adaptor.getVal(), virtualPods); - for (auto [id, newMember] : idToMember) { + for (const auto &[id, newMember] : idToMember) { Value scalarValue = virtualLeafValues ? virtualLeafValues->at(id) : genReadAlongPath(rewriter, op.getLoc(), adaptor.getVal(), id); @@ -1815,7 +1817,7 @@ class SplitPodInMemberReadOp : public OpConversionPattern { repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); VirtualPodLeafMap leafValues; - for (auto [id, newMember] : idToMember) { + for (const auto &[id, newMember] : idToMember) { leafValues[id] = rewriter.create( op.getLoc(), newMember.second, adaptor.getComponent(), newMember.first ); @@ -2315,7 +2317,7 @@ struct LoopPodSlot { /// Return the tracked loop slot for `podRef.recordName`, or null if not found. static LoopPodSlot * lookupLoopSlot(SmallVectorImpl &slots, Value podRef, StringAttr recordName) { - auto it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { + auto *it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { return slot.matches(podRef, recordName); }); return it == slots.end() ? nullptr : &*it; @@ -2323,7 +2325,7 @@ lookupLoopSlot(SmallVectorImpl &slots, Value podRef, StringAttr rec /// Return whether a loop slot is tracked for `podRef.recordName`. static bool hasLoopSlot(ArrayRef slots, Value podRef, StringAttr recordName) { - auto it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { + const auto *it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) { return slot.matches(podRef, recordName); }); return it != slots.end(); From 26c1cc527555286ba4472fef99e25c2075776316 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 24 Jun 2026 08:33:46 -0500 Subject: [PATCH 18/36] fix: Flatten array-valued leaves for initialized array.new --- .../POD/Transforms/PodToScalarPass.cpp | 127 +++++++++++++++++- .../PodToScalar/array_leaf_in_pod_array.llzk | 34 +++++ 2 files changed, 154 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 55bd0903c..1bc179508 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -428,6 +428,18 @@ static SmallVector flattenConvertedValues(RangeOfRanges ranges) { return values; } +/// Create an array value that callers can fully initialize via explicit writes or inserts. +/// +/// Use `llzk.nondet` as the base when affine-map dimensions are present because `array.new` +/// cannot carry both inline elements and affine-map instantiation operands. +inline static Value createWritableArrayValue(OpBuilder &bldr, Location loc, ArrayType arrTy) { + if (hasAffineMapAttr(arrTy)) { + return bldr.create(loc, arrTy); + } else { + return bldr.create(loc, arrTy); + } +} + /// Generate `arith.constant` indices for one static array element position. static SmallVector genArrayIndexConstants(OpBuilder &bldr, Location loc, ArrayAttr index) { SmallVector indices; @@ -905,16 +917,36 @@ class SplitPodArrayCreateArrayOp : public OpConversionPattern { replacements.reserve(splitTypes.size()); DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr(); if (isNullOrEmpty(numDimsPerMap)) { + if (adaptor.getElements().empty()) { + for (Type splitType : splitTypes) { + replacements.push_back( + rewriter.create(op.getLoc(), llvm::cast(splitType)) + ); + } + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } + + auto elementIndices = arrTy.getSubelementIndices(); + assert(elementIndices && "array.new with explicit elements requires a static array shape"); + assert( + elementIndices->size() == adaptor.getElements().size() && + "array.new element count must match the outer array cardinality" + ); + + // Inline initializers are linearized only across the original outer array dimensions. When + // a flattened POD leaf is itself an array, populate the rewritten split array one outer + // element at a time so each leaf array becomes a subarray insert rather than a malformed + // inline operand to the flattened `array.new`. for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { - SmallVector splitElements; - splitElements.reserve(adaptor.getElements().size()); - for (ValueRange elementRange : adaptor.getElements()) { + Value splitArray = + createWritableArrayValue(rewriter, op.getLoc(), llvm::cast(splitType)); + for (auto [index, elementRange] : llvm::zip_equal(*elementIndices, adaptor.getElements())) { Value element = getSingleConvertedValue(elementRange); - splitElements.push_back(genReadAlongPath(rewriter, op.getLoc(), element, id)); + Value leafValue = genReadAlongPath(rewriter, op.getLoc(), element, id); + genArrayWrite(rewriter, op.getLoc(), splitArray, index, leafValue); } - replacements.push_back(rewriter.create( - op.getLoc(), llvm::cast(splitType), splitElements - )); + replacements.push_back(splitArray); } } else { SmallVector> mapOperandStorage; @@ -1556,6 +1588,85 @@ class SplitInitFromNewPodOp : public OpConversionPattern { } }; +/// Rewrite `array.new` when explicit elements are PODs or flattened leaf arrays. +/// +/// This occurs after the array-of-POD stage has already converted the result type away from +/// `!array.type<... x !pod.type<...>>`, but before the POD operands themselves have been fully +/// scalarized. Rebuild the destination array explicitly so leaf arrays become subarray inserts +/// rather than invalid inline operands to the flattened `array.new`. +class SplitPodElementCreateArrayOp : public OpConversionPattern { + const VirtualPodValueMap &virtualPods; + +public: + SplitPodElementCreateArrayOp(MLIRContext *ctx, const VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + + static bool legal(CreateArrayOp op) { + return !llvm::any_of(op.getElements().getTypes(), [](Type type) { + return splittablePod(type) || llvm::isa(type); + }); + } + + LogicalResult match(CreateArrayOp op) const override { return failure(legal(op)); } + + void + rewrite(CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector leafElements; + leafElements.reserve(adaptor.getElements().size()); + + Type leafType; + for (Value element : adaptor.getElements()) { + SmallVector flattenedValues; + if (splittablePod(element.getType())) { + processInputOperand(op.getLoc(), element, flattenedValues, rewriter, &virtualPods); + } else { + flattenedValues.push_back(element); + } + + assert( + flattenedValues.size() == 1 && + "array.new elements should already have been split to a single flattened leaf" + ); + if (!leafType) { + leafType = flattenedValues.front().getType(); + } else { + assert( + leafType == flattenedValues.front().getType() && "array.new elements must stay uniform" + ); + } + leafElements.push_back(flattenedValues.front()); + } + + size_t leafRank = 0; + if (auto leafArrTy = llvm::dyn_cast_if_present(leafType)) { + leafRank = leafArrTy.getDimensionSizes().size(); + } + ArrayType arrTy = op.getType(); + assert( + arrTy.getDimensionSizes().size() >= leafRank && "flattened leaf rank exceeds array rank" + ); + size_t outerRank = arrTy.getDimensionSizes().size() - leafRank; + assert(outerRank > 0 && "array.new elements must populate at least one outer array dimension"); + + ArrayType outerIndexTy = + ArrayType::get(arrTy.getElementType(), arrTy.getDimensionSizes().take_front(outerRank)); + auto elementIndices = outerIndexTy.getSubelementIndices(); + assert( + elementIndices && "array.new with explicit POD elements requires static outer dimensions" + ); + assert( + elementIndices->size() == leafElements.size() && + "array.new element count must match the outer array cardinality" + ); + + Value rebuiltArray = createWritableArrayValue(rewriter, op.getLoc(), arrTy); + for (auto [index, leafValue] : llvm::zip_equal(*elementIndices, leafElements)) { + genArrayWrite(rewriter, op.getLoc(), rebuiltArray, index, leafValue); + } + rewriter.replaceOp(op, rebuiltArray); + } +}; + /// Rewrite pod-typed function signatures to pass one scalar per POD record instead. /// /// Each pod argument is expanded into one scalar argument per record, and each pod result is @@ -1882,6 +1993,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM RewritePatternSet patterns(ctx); patterns.add(ctx); + patterns.add(ctx, virtualPods); patterns.add(ctx, virtualPods); patterns.add( ctx, symTables, memberRepMap, virtualPods @@ -1891,6 +2003,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM ConversionTarget target(*ctx); baseTargetSetup(target); target.addDynamicallyLegalOp(SplitInitFromNewPodOp::legal); + target.addDynamicallyLegalOp(SplitPodElementCreateArrayOp::legal); target.addDynamicallyLegalOp(SplitPodInFuncDefOp::legal); target.addDynamicallyLegalOp(SplitPodInReturnOp::legal); target.addDynamicallyLegalOp(SplitPodInCallOp::legal); diff --git a/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk b/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk index cb7ffaae1..93415f4ba 100644 --- a/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk +++ b/test/Transforms/PodToScalar/array_leaf_in_pod_array.llzk @@ -17,6 +17,40 @@ module attributes {llzk.lang} { // CHECK-NEXT: } // ----- +!Elem = !pod.type<[@vals: !array.type<3 x index>]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @pack( + %a0: index, %a1: index, %a2: index, %b0: index, %b1: index, %b2: index + ) -> !ElemArray { + %vals0 = array.new %a0, %a1, %a2 : !array.type<3 x index> + %vals1 = array.new %b0, %b1, %b2 : !array.type<3 x index> + %lhs = pod.new : !Elem + pod.write %lhs[@vals] = %vals0 : !Elem, !array.type<3 x index> + %rhs = pod.new : !Elem + pod.write %rhs[@vals] = %vals1 : !Elem, !array.type<3 x index> + %arr = array.new %lhs, %rhs : !ElemArray + function.return %arr : !ElemArray + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @pack( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_2:[0-9a-zA-Z_\.]+]]: index, %[[VAL_3:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_4:[0-9a-zA-Z_\.]+]]: index, %[[VAL_5:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> !array.type<2,3 x index> { +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = array.new %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : <3 x index> +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.new %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : <3 x index> +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = array.new : <2,3 x index> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.insert %[[VAL_8]]{{\[}}%[[VAL_9]]] = %[[VAL_6]] : <2,3 x index>, <3 x index> +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.insert %[[VAL_8]]{{\[}}%[[VAL_10]]] = %[[VAL_7]] : <2,3 x index>, <3 x index> +// CHECK-NEXT: function.return %[[VAL_8]] : !array.type<2,3 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + !Elem = !pod.type<[@vals: !array.type<3 x index>, @tag: index]> !ElemArray = !array.type<2 x !Elem> module attributes {llzk.lang} { From 69cfd98972bb744b0918d1ef146da701585498dd Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 24 Jun 2026 09:22:37 -0500 Subject: [PATCH 19/36] fix: Convert pod.read array fields before splitting array reads --- .../POD/Transforms/PodToScalarPass.cpp | 124 +++++++++++++++++- .../array_read_from_pod_field.llzk | 22 ++++ 2 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 test/Transforms/PodToScalar/array_read_from_pod_field.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 1bc179508..da54f8038 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -635,6 +635,12 @@ static bool canResolveVirtualPodRead(ReadPodOp op, const VirtualPodValueMap &vir return llvm::isa(recType) || !splittablePodArray(recType); } +/// Return `true` iff step 2 should defer splitting this array read until POD-aware rewriting. +static bool shouldDeferPodArrayReadToStep3(ReadArrayOp op) { + return splittablePodArray(op.getArrRefType()) && + llvm::isa_and_present(op.getArrRef().getDefiningOp()); +} + /// Return the suffixes to append to a function arg/result name when splitting the given type. static SmallVector getSplitRecordNameSuffixes(Type type) { SmallVector suffixes; @@ -976,7 +982,9 @@ class SplitPodArrayReadArrayOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - static bool legal(ReadArrayOp op) { return !splittablePodArray(op.getArrRefType()); } + static bool legal(ReadArrayOp op) { + return !splittablePodArray(op.getArrRefType()) || shouldDeferPodArrayReadToStep3(op); + } LogicalResult matchAndRewrite( ReadArrayOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter @@ -1940,6 +1948,116 @@ class SplitPodInMemberReadOp : public OpConversionPattern { } }; +static WritePodOp findNearestForwardableWriteInBlock(ReadPodOp readOp); + +/// Collect split leaf arrays from a value materialized back to an aggregate array-of-POD type. +static bool tryCollectMaterializedSplitPodArrayLeafValues( + Value arrayValue, ArrayType arrTy, ArrayRef splitTypes, SmallVectorImpl &leafArrays +) { + auto cast = arrayValue.getDefiningOp(); + if (!cast || cast->getNumResults() != 1 || cast.getResult(0).getType() != arrTy || + cast->getNumOperands() != splitTypes.size()) { + return false; + } + + for (auto [operand, splitType] : llvm::zip_equal(cast.getOperands(), splitTypes)) { + if (operand.getType() != splitType) { + return false; + } + leafArrays.push_back(operand); + } + return true; +} + +/// Collect split leaf arrays for an array-of-POD value backed by a direct `pod.read`. +static bool tryCollectReadPodSplitPodArrayLeafValues( + ReadPodOp readOp, ArrayType arrTy, ArrayRef splitIds, ArrayRef splitTypes, + const VirtualPodValueMap &virtualPods, SmallVectorImpl &leafArrays +) { + if (const VirtualPodLeafMap *podLeafValues = + lookupVirtualPodLeafMap(readOp.getPodRef(), virtualPods)) { + leafArrays.reserve(splitIds.size()); + for (const RecordChain &id : splitIds) { + SmallVector fullChain {readOp.getRecordNameAttr()}; + llvm::append_range(fullChain, id.nameList); + auto it = podLeafValues->find(RecordChain(fullChain)); + if (it == podLeafValues->end() || + it->second.getType() != getFlattenedTypeAlongPath(arrTy, id.nameList)) { + return false; + } + leafArrays.push_back(it->second); + } + return true; + } + + if (WritePodOp writeOp = findNearestForwardableWriteInBlock(readOp)) { + return tryCollectMaterializedSplitPodArrayLeafValues( + writeOp.getValue(), arrTy, splitTypes, leafArrays + ); + } + + return false; +} + +/// Resolve deferred `array.read` from `pod.read`-produced array-of-POD values. +class ResolvePodReadBackedArrayReadOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + +public: + ResolvePodReadBackedArrayReadOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + + static bool canResolve(ReadArrayOp op, const VirtualPodValueMap &virtualPods) { + if (!shouldDeferPodArrayReadToStep3(op)) { + return false; + } + + ArrayType arrTy = op.getArrRefType(); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector ignoredLeafArrays; + return tryCollectReadPodSplitPodArrayLeafValues( + llvm::cast(op.getArrRef().getDefiningOp()), arrTy, splitIds, splitTypes, + virtualPods, ignoredLeafArrays + ); + } + + LogicalResult matchAndRewrite( + ReadArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + auto fieldRead = op.getArrRef().getDefiningOp(); + if (!fieldRead) { + return failure(); + } + + ArrayType arrTy = op.getArrRefType(); + PodType podTy = llvm::cast(arrTy.getElementType()); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector splitLeafArrays; + if (!tryCollectReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, splitLeafArrays + )) { + return failure(); + } + + SmallVector indices(adaptor.getIndices().begin(), adaptor.getIndices().end()); + DenseMap leafValues; + for (auto [id, leafArray] : llvm::zip_equal(splitIds, splitLeafArrays)) { + leafValues[id] = genArrayRead(rewriter, op.getLoc(), leafArray, indices); + } + + NewPodOp pod = rewriter.create(op.getLoc(), podTy); + virtualPods[pod] = std::move(leafValues); + rewriter.replaceOp(op, pod); + return success(); + } +}; + /// Resolve reads from a virtual POD placeholder without materializing the whole aggregate. class ResolveVirtualPodReadOp : public OpConversionPattern { VirtualPodValueMap &virtualPods; @@ -1998,6 +2116,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM patterns.add( ctx, symTables, memberRepMap, virtualPods ); + patterns.add(ctx, virtualPods); patterns.add(ctx, virtualPods); ConversionTarget target(*ctx); @@ -2009,6 +2128,9 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp(SplitPodInCallOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberWriteOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberReadOp::legal); + target.addDynamicallyLegalOp([&virtualPods](ReadArrayOp op) { + return !ResolvePodReadBackedArrayReadOp::canResolve(op, virtualPods); + }); target.addDynamicallyLegalOp([&virtualPods](ReadPodOp op) { return !canResolveVirtualPodRead(op, virtualPods); }); diff --git a/test/Transforms/PodToScalar/array_read_from_pod_field.llzk b/test/Transforms/PodToScalar/array_read_from_pod_field.llzk new file mode 100644 index 000000000..e5062bebd --- /dev/null +++ b/test/Transforms/PodToScalar/array_read_from_pod_field.llzk @@ -0,0 +1,22 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<2 x !Item>]> +module attributes {llzk.lang} { + function.def @read_item(%p: !Outer, %i: index) -> index { + %items = pod.read %p[@items] : !Outer, !array.type<2 x !Item> + %item = array.read %items[%i] : !array.type<2 x !Item>, !Item + %x = pod.read %item[@x] : !Item, index + function.return %x : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @read_item(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[IDX:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[UNUSED0:[0-9a-zA-Z_\.]+]] = array.read %[[ARR]]{{\[}}%[[C0]]] : <2 x index>, index +// CHECK-NEXT: %[[C1:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[UNUSED1:[0-9a-zA-Z_\.]+]] = array.read %[[ARR]]{{\[}}%[[C1]]] : <2 x index>, index +// CHECK-NEXT: %[[SEL:[0-9a-zA-Z_\.]+]] = array.read %[[ARR]]{{\[}}%[[IDX]]] : <2 x index>, index +// CHECK-NEXT: function.return %[[SEL]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } From 3449ad66259efd2fea566915fdac4065880f7732 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 13:17:02 -0500 Subject: [PATCH 20/36] fix: Preserve affine operands for leaf-array dimensions --- .../POD/Transforms/PodToScalarPass.cpp | 398 ++++++++++++++---- .../array_new_affine_leaf_array.llzk | 57 +++ .../array_new_affine_leaf_array_conflict.llzk | 20 + 3 files changed, 400 insertions(+), 75 deletions(-) create mode 100644 test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk create mode 100644 test/Transforms/PodToScalar/array_new_affine_leaf_array_conflict.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index da54f8038..e0214aae7 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -82,6 +82,7 @@ #include "llzk/Dialect/POD/IR/Types.h" #include "llzk/Dialect/POD/Transforms/TransformationPasses.h" #include "llzk/Dialect/Polymorphic/IR/Dialect.h" +#include "llzk/Dialect/Polymorphic/IR/Ops.h" #include "llzk/Dialect/RAM/IR/Dialect.h" #include "llzk/Dialect/String/IR/Dialect.h" #include "llzk/Dialect/Struct/IR/Ops.h" @@ -104,6 +105,7 @@ #include #include +#include // Include the generated base pass class definitions. namespace llzk::pod { @@ -117,6 +119,7 @@ using namespace llzk::array; using namespace llzk::pod; using namespace llzk::function; using namespace llzk::component; +using namespace llzk::polymorphic; #define DEBUG_TYPE "llzk-pod-to-scalar" @@ -155,6 +158,57 @@ template <> struct DenseMapInfo { namespace { +/// Return whether the given read/write access targets the same POD record. +inline static bool isSamePodRecord(ReadPodOp readOp, Value podRef, StringAttr recordName) { + return readOp.getPodRef() == podRef && readOp.getRecordNameAttr() == recordName; +} + +/// Return whether the given read/write access targets the same POD record. +inline static bool isSamePodRecord(WritePodOp writeOp, Value podRef, StringAttr recordName) { + return writeOp.getPodRef() == podRef && writeOp.getRecordNameAttr() == recordName; +} + +/// Return whether `op` contains a nested write to `podRef.recordName`. +static bool hasNestedWriteToRecord(Operation &op, Value podRef, StringAttr recordName) { + return walkContainsMatch(op, [&](WritePodOp writeOp) { + return writeOp.getOperation() != &op && isSamePodRecord(writeOp, podRef, recordName); + }); +} + +/// Return whether `op` contains any read from `podRef.recordName`. +static bool hasReadFromRecord(Operation &op, Value podRef, StringAttr recordName) { + return walkContainsMatch(op, [&podRef, &recordName](ReadPodOp readOp) { + return isSamePodRecord(readOp, podRef, recordName); + }); +} + +/// Return whether `op` or any nested operation uses `value` as an operand. +static bool hasValueUse(Operation &op, Value value) { + return walkContainsMatch(op, [&value](Operation *nestedOp) { + return llvm::is_contained(nestedOp->getOperands(), value); + }); +} + +/// Return the nearest preceding same-record write that can be forwarded to `readOp`. +/// +/// This fold is intentionally conservative: it only forwards through intervening operations that do +/// not use the POD value at all. That keeps the rewrite local and avoids reasoning about other +/// whole-POD uses or record accesses that may observe mutation ordering. +static WritePodOp findNearestForwardableWriteInBlock(ReadPodOp readOp) { + Value podRef = readOp.getPodRef(); + StringAttr recordName = readOp.getRecordNameAttr(); + + for (Operation *op = readOp->getPrevNode(); op; op = op->getPrevNode()) { + if (!hasValueUse(*op, podRef)) { + continue; + } + + auto writeOp = dyn_cast(op); + return writeOp && isSamePodRecord(writeOp, podRef, recordName) ? writeOp : nullptr; + } + return nullptr; +} + /// If the given PodType can be split into scalars (always true for PodType), return it. inline static PodType splittablePod(PodType pt) { return pt; } @@ -364,6 +418,19 @@ static SmallVector getSplitPodArrayRecordNameSuffixes(Type type) { return suffixes; } +/// Insert a `poly.unifiable_cast` when a rewritten value must match a more specific type. +/// +/// This is the common bridge between wildcard-backed storage values and the more precise types +/// expected by surrounding rewritten IR. The cast is only emitted when the source and target +/// types unify and differ syntactically. +static Value castValueToTypeIfNeeded(OpBuilder &bldr, Location loc, Value value, Type targetType) { + if (value.getType() == targetType) { + return value; + } + assert(typesUnify(value.getType(), targetType) && "expected compatible rewritten types"); + return bldr.create(loc, targetType, value); +} + /// Create a `pod.read` for one record of `podRef`. inline static ReadPodOp genRead(OpBuilder &bldr, Location loc, Value podRef, StringAttr recordName) { @@ -375,7 +442,11 @@ genRead(OpBuilder &bldr, Location loc, Value podRef, StringAttr recordName) { /// Create a `pod.write` for one record of `podRef`. inline static WritePodOp genWrite(OpBuilder &bldr, Location loc, Value podRef, StringAttr recordName, Value value) { - return bldr.create(loc, podRef, recordName, value); + Type recordType = + llvm::cast(podRef.getType()).getRecordMap().lookup(recordName.getValue()); + return bldr.create( + loc, podRef, recordName, castValueToTypeIfNeeded(bldr, loc, value, recordType) + ); } /// Return the single converted value from a 1:N adaptor range. @@ -428,6 +499,41 @@ static SmallVector flattenConvertedValues(RangeOfRanges ranges) { return values; } +/// Replace any AffineMap-backed array dimensions nested within `type` with wildcard `?` dims. +/// +/// This preserves the overall array nesting while erasing only the affine-map dimensions that +/// cannot always be witnessed after flattening a POD leaf array into a split array value. +static Type replaceAffineMapArrayDimsWithWildcards(Type type) { + auto arrTy = llvm::dyn_cast(type); + if (!arrTy) { + return type; + } + + Builder builder(arrTy.getContext()); + SmallVector dims; + dims.reserve(arrTy.getDimensionSizes().size()); + for (Attribute dimSize : arrTy.getDimensionSizes()) { + if (llvm::isa(dimSize)) { + dims.push_back(builder.getIndexAttr(ShapedType::kDynamic)); + } else { + dims.push_back(dimSize); + } + } + + return arrTy.cloneWith(replaceAffineMapArrayDimsWithWildcards(arrTy.getElementType()), dims); +} + +/// Return the wildcard-backed storage split type for one flattened POD leaf. +/// +/// The precise split type preserves the original affine maps in the flattened leaf array. The +/// storage split type uses the same outer shape but replaces hidden leaf-array affine dims with +/// `?` until a matching instantiation can be recovered from concrete leaf-array values. +static ArrayType getSplitPodArrayStorageType(ArrayType arrTy, ArrayRef recordChain) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + Type leafType = getFlattenedTypeAlongPath(elemPodTy, recordChain); + return flattenArrayElementType(arrTy, replaceAffineMapArrayDimsWithWildcards(leafType)); +} + /// Create an array value that callers can fully initialize via explicit writes or inserts. /// /// Use `llzk.nondet` as the base when affine-map dimensions are present because `array.new` @@ -440,6 +546,104 @@ inline static Value createWritableArrayValue(OpBuilder &bldr, Location loc, Arra } } +/// Store the affine-map operand groups needed to rebuild one concrete array instantiation. +/// +/// The layout mirrors `array.new`: `mapOperandStorage` keeps each instantiation group separately, +/// and `numDimsPerMap` records how many values in each group are dimensional arguments. +struct ArrayInstantiationInfo { + SmallVector> mapOperandStorage; + SmallVector numDimsPerMap; +}; + +/// Return `true` iff two recovered array instantiations can be rebuilt identically. +static bool equivalentArrayInstantiationInfo( + const ArrayInstantiationInfo &lhs, const ArrayInstantiationInfo &rhs +) { + if (lhs.numDimsPerMap != rhs.numDimsPerMap || + lhs.mapOperandStorage.size() != rhs.mapOperandStorage.size()) { + return false; + } + + for (auto [lhsGroup, rhsGroup] : llvm::zip_equal(lhs.mapOperandStorage, rhs.mapOperandStorage)) { + if (lhsGroup != rhsGroup) { + return false; + } + } + + return true; +} + +/// Try to recover affine-map instantiation operands from a concrete array-producing value. +/// +/// This peels compatibility casts, follows simple `pod.read` to dominating `pod.write` +/// forwarding, and succeeds only when the value ultimately traces back to a concrete +/// `array.new` carrying the instantiation groups. +static std::optional tryGetArrayInstantiationInfo(Value value) { + while (auto cast = value.getDefiningOp()) { + value = cast.getInput(); + } + + if (ReadPodOp read = value.getDefiningOp()) { + if (WritePodOp write = findNearestForwardableWriteInBlock(read)) { + return tryGetArrayInstantiationInfo(write.getValue()); + } + return std::nullopt; + } + + auto create = value.getDefiningOp(); + if (!create) { + return std::nullopt; + } + + ArrayInstantiationInfo info; + info.mapOperandStorage.reserve(create.getMapOperands().size()); + for (OperandRange group : create.getMapOperands()) { + info.mapOperandStorage.emplace_back(group.begin(), group.end()); + } + + if (DenseI32ArrayAttr numDimsPerMap = create.getNumDimsPerMapAttr()) { + llvm::append_range(info.numDimsPerMap, numDimsPerMap.asArrayRef()); + } + + return info; +} + +/// Describe whether a set of leaf arrays shares one recoverable instantiation. +enum class CommonArrayInstantiationStatus : std::uint8_t { + unavailable, + inferred, + conflict, +}; + +/// Recover a single shared affine-map instantiation from all of `values`, if one exists. +/// +/// Returns `inferred` when every value resolves to the same concrete `array.new` +/// instantiation, `unavailable` when any value has no recoverable witness, and `conflict` +/// when the recovered instantiations disagree. +static CommonArrayInstantiationStatus +inferCommonArrayInstantiation(ArrayRef values, ArrayInstantiationInfo &result) { + bool initialized = false; + for (Value value : values) { + std::optional info = tryGetArrayInstantiationInfo(value); + if (!info) { + return CommonArrayInstantiationStatus::unavailable; + } + + if (!initialized) { + result = std::move(*info); + initialized = true; + continue; + } + + if (!equivalentArrayInstantiationInfo(result, *info)) { + return CommonArrayInstantiationStatus::conflict; + } + } + + return initialized ? CommonArrayInstantiationStatus::inferred + : CommonArrayInstantiationStatus::unavailable; +} + /// Generate `arith.constant` indices for one static array element position. static SmallVector genArrayIndexConstants(OpBuilder &bldr, Location loc, ArrayAttr index) { SmallVector indices; @@ -483,6 +687,7 @@ genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayRef ind Type t = arrayRef.getType(); assert(llvm::isa(t) && "array access must target an array type"); ArrayType arrTy = llvm::cast(t); + value = castValueToTypeIfNeeded(bldr, loc, value, getArraySelectionType(arrTy, indices.size())); if (indices.size() == arrTy.getDimensionSizes().size()) { bldr.create(loc, arrayRef, indices, value); return; @@ -599,15 +804,19 @@ lookupVirtualPodLeafMap(Value podValue, const VirtualPodValueMap &virtualPods) { } /// Collect flattened POD leaf values in canonical traversal order. -static SmallVector -orderedVirtualPodLeafValues(PodType podTy, const VirtualPodLeafMap &leafValues) { +static SmallVector orderedVirtualPodLeafValues( + PodType podTy, Location loc, OpBuilder &bldr, const VirtualPodLeafMap &leafValues +) { SmallVector orderedValues; SmallVector recordChain; - forEachPodLeaf(podTy, recordChain, [&leafValues, &orderedValues](const RecordChain &id, Type) { + forEachPodLeaf( + podTy, recordChain, + [&leafValues, &orderedValues, &bldr, loc](const RecordChain &id, Type leafType) { auto it = leafValues.find(id); assert(it != leafValues.end() && "missing virtual POD leaf value"); - orderedValues.push_back(it->second); - }); + orderedValues.push_back(castValueToTypeIfNeeded(bldr, loc, it->second, leafType)); + } + ); return orderedValues; } @@ -667,7 +876,9 @@ static void processInputOperand( if (PodType pt = splittablePod(operand.getType())) { if (virtualPods) { if (const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(operand, *virtualPods)) { - llvm::append_range(newOperands, orderedVirtualPodLeafValues(pt, *leafValues)); + llvm::append_range( + newOperands, orderedVirtualPodLeafValues(pt, loc, rewriter, *leafValues) + ); return; } } @@ -864,6 +1075,10 @@ step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &me } /// Type converter that replaces each array-of-POD type with one parallel array type per POD leaf. +/// +/// Besides splitting result types, this also materializes compatibility casts between precise +/// split array types and wildcard-backed storage split types when target or block-argument +/// conversion needs to cross that boundary. class PodArrayTypeConverter : public TypeConverter { public: PodArrayTypeConverter() { @@ -877,6 +1092,16 @@ class PodArrayTypeConverter : public TypeConverter { return success(); } ); + + auto materializeCast = [](OpBuilder &bldr, Type targetType, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1 || !typesUnify(inputs.front().getType(), targetType)) { + return {}; + } + return castValueToTypeIfNeeded(bldr, loc, inputs.front(), targetType); + }; + addTargetMaterialization(materializeCast); + addArgumentMaterialization(materializeCast); } }; @@ -902,6 +1127,15 @@ class SplitPodArrayNonDetOp : public OpConversionPattern { }; /// Split `array.new` of array-of-POD type into one `array.new` per parallel leaf array. +/// +/// For each leaf, the precise split type preserves the original affine maps in the flattened leaf +/// array. When hidden leaf-array affine dims have no direct witness, the rewrite may first build a +/// wildcard-backed storage split type with the same outer shape and cast back to the precise type. +/// +/// Uninitialized `array.new` uses that storage fallback directly when needed. Explicit-element +/// `array.new` tries to infer one shared affine-map instantiation from all leaf arrays so it can +/// materialize the precise split type immediately. If different elements imply conflicting +/// instantiations, the rewrite remains a hard failure. class SplitPodArrayCreateArrayOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -924,9 +1158,12 @@ class SplitPodArrayCreateArrayOp : public OpConversionPattern { DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr(); if (isNullOrEmpty(numDimsPerMap)) { if (adaptor.getElements().empty()) { - for (Type splitType : splitTypes) { + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + ArrayType preciseSplitType = llvm::cast(splitType); + ArrayType storageSplitType = getSplitPodArrayStorageType(arrTy, id.nameList); + Value splitArray = rewriter.create(op.getLoc(), storageSplitType); replacements.push_back( - rewriter.create(op.getLoc(), llvm::cast(splitType)) + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitArray, preciseSplitType) ); } rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); @@ -945,14 +1182,55 @@ class SplitPodArrayCreateArrayOp : public OpConversionPattern { // element at a time so each leaf array becomes a subarray insert rather than a malformed // inline operand to the flattened `array.new`. for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { - Value splitArray = - createWritableArrayValue(rewriter, op.getLoc(), llvm::cast(splitType)); - for (auto [index, elementRange] : llvm::zip_equal(*elementIndices, adaptor.getElements())) { + ArrayType preciseSplitType = llvm::cast(splitType); + ArrayType storageSplitType = getSplitPodArrayStorageType(arrTy, id.nameList); + + SmallVector leafValues; + leafValues.reserve(adaptor.getElements().size()); + for (ValueRange elementRange : adaptor.getElements()) { Value element = getSingleConvertedValue(elementRange); - Value leafValue = genReadAlongPath(rewriter, op.getLoc(), element, id); + leafValues.push_back(genReadAlongPath(rewriter, op.getLoc(), element, id)); + } + + ArrayType materializedType = storageSplitType; + Value splitArray; + if (storageSplitType != preciseSplitType) { + ArrayInstantiationInfo instantiationInfo; + switch (inferCommonArrayInstantiation(leafValues, instantiationInfo)) { + case CommonArrayInstantiationStatus::conflict: + // TODO: this POD could be promoted to a complete `struct.def` but that's not easy. + op.emitOpError( + "with POD elements having conflicting affine map instantiations cannot be promoted " + "to higher dimensional array" + ); + return failure(); + case CommonArrayInstantiationStatus::inferred: { + materializedType = preciseSplitType; + SmallVector mapOperands; + mapOperands.reserve(instantiationInfo.mapOperandStorage.size()); + for (const SmallVector &values : instantiationInfo.mapOperandStorage) { + mapOperands.push_back(values); + } + splitArray = rewriter.create( + op.getLoc(), materializedType, mapOperands, instantiationInfo.numDimsPerMap + ); + break; + } + case CommonArrayInstantiationStatus::unavailable: + break; + } + } + + if (!splitArray) { + splitArray = createWritableArrayValue(rewriter, op.getLoc(), materializedType); + } + + for (auto [index, leafValue] : llvm::zip_equal(*elementIndices, leafValues)) { genArrayWrite(rewriter, op.getLoc(), splitArray, index, leafValue); } - replacements.push_back(splitArray); + replacements.push_back( + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitArray, preciseSplitType) + ); } } else { SmallVector> mapOperandStorage; @@ -965,10 +1243,15 @@ class SplitPodArrayCreateArrayOp : public OpConversionPattern { for (const SmallVector &values : mapOperandStorage) { mapOperands.push_back(values); } - for (Type splitType : splitTypes) { - replacements.push_back(rewriter.create( - op.getLoc(), llvm::cast(splitType), mapOperands, numDimsPerMap - )); + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + ArrayType preciseSplitType = llvm::cast(splitType); + ArrayType storageSplitType = getSplitPodArrayStorageType(arrTy, id.nameList); + Value splitArray = rewriter.create( + op.getLoc(), storageSplitType, mapOperands, numDimsPerMap + ); + replacements.push_back( + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitArray, preciseSplitType) + ); } } @@ -1948,9 +2231,10 @@ class SplitPodInMemberReadOp : public OpConversionPattern { } }; -static WritePodOp findNearestForwardableWriteInBlock(ReadPodOp readOp); - -/// Collect split leaf arrays from a value materialized back to an aggregate array-of-POD type. +/// Collect precise split leaf arrays from a value re-materialized as an aggregate array-of-POD. +/// +/// This recognizes the temporary aggregate form produced by dialect conversion casts and unwraps +/// it back into the parallel split arrays expected by the late pod-array read resolvers. static bool tryCollectMaterializedSplitPodArrayLeafValues( Value arrayValue, ArrayType arrTy, ArrayRef splitTypes, SmallVectorImpl &leafArrays ) { @@ -1969,7 +2253,11 @@ static bool tryCollectMaterializedSplitPodArrayLeafValues( return true; } -/// Collect split leaf arrays for an array-of-POD value backed by a direct `pod.read`. +/// Collect precise split leaf arrays for an array-of-POD value backed by a direct `pod.read`. +/// +/// This first consults virtual POD leaf storage and, if unavailable, falls back to forwarding +/// through a dominating same-record `pod.write` whose value was previously materialized as split +/// arrays. static bool tryCollectReadPodSplitPodArrayLeafValues( ReadPodOp readOp, ArrayType arrTy, ArrayRef splitIds, ArrayRef splitTypes, const VirtualPodValueMap &virtualPods, SmallVectorImpl &leafArrays @@ -1982,7 +2270,7 @@ static bool tryCollectReadPodSplitPodArrayLeafValues( llvm::append_range(fullChain, id.nameList); auto it = podLeafValues->find(RecordChain(fullChain)); if (it == podLeafValues->end() || - it->second.getType() != getFlattenedTypeAlongPath(arrTy, id.nameList)) { + !typesUnify(it->second.getType(), getFlattenedTypeAlongPath(arrTy, id.nameList))) { return false; } leafArrays.push_back(it->second); @@ -2000,6 +2288,10 @@ static bool tryCollectReadPodSplitPodArrayLeafValues( } /// Resolve deferred `array.read` from `pod.read`-produced array-of-POD values. +/// +/// When step 2 defers a read because the array-of-POD came from a POD record, this pattern +/// reconstructs the per-leaf split arrays, performs the array read on each leaf array, and then +/// rebuilds the element POD virtually instead of materializing the whole aggregate array first. class ResolvePodReadBackedArrayReadOp : public OpConversionPattern { VirtualPodValueMap &virtualPods; @@ -2059,6 +2351,9 @@ class ResolvePodReadBackedArrayReadOp : public OpConversionPattern }; /// Resolve reads from a virtual POD placeholder without materializing the whole aggregate. +/// +/// This pattern answers `pod.read` directly from virtual leaf storage, rebuilding nested POD +/// subrecords on demand and casting scalar leaves back to the precise record type when needed. class ResolveVirtualPodReadOp : public OpConversionPattern { VirtualPodValueMap &virtualPods; @@ -2097,7 +2392,11 @@ class ResolveVirtualPodReadOp : public OpConversionPattern { return failure(); } - rewriter.replaceOp(op, leafValues->at(RecordChain(prefix))); + rewriter.replaceOp( + op, castValueToTypeIfNeeded( + rewriter, op.getLoc(), leafValues->at(RecordChain(prefix)), recordType + ) + ); return success(); } }; @@ -2153,57 +2452,6 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM return success(); } -/// Return whether the given read/write access targets the same POD record. -inline static bool isSamePodRecord(ReadPodOp readOp, Value podRef, StringAttr recordName) { - return readOp.getPodRef() == podRef && readOp.getRecordNameAttr() == recordName; -} - -/// Return whether the given read/write access targets the same POD record. -inline static bool isSamePodRecord(WritePodOp writeOp, Value podRef, StringAttr recordName) { - return writeOp.getPodRef() == podRef && writeOp.getRecordNameAttr() == recordName; -} - -/// Return whether `op` contains a nested write to `podRef.recordName`. -static bool hasNestedWriteToRecord(Operation &op, Value podRef, StringAttr recordName) { - return walkContainsMatch(op, [&](WritePodOp writeOp) { - return writeOp.getOperation() != &op && isSamePodRecord(writeOp, podRef, recordName); - }); -} - -/// Return whether `op` contains any read from `podRef.recordName`. -static bool hasReadFromRecord(Operation &op, Value podRef, StringAttr recordName) { - return walkContainsMatch(op, [&](ReadPodOp readOp) { - return isSamePodRecord(readOp, podRef, recordName); - }); -} - -/// Return whether `op` or any nested operation uses `value` as an operand. -static bool hasValueUse(Operation &op, Value value) { - return walkContainsMatch(op, [&value](Operation *nestedOp) { - return llvm::is_contained(nestedOp->getOperands(), value); - }); -} - -/// Return the nearest preceding same-record write that can be forwarded to `readOp`. -/// -/// This fold is intentionally conservative: it only forwards through intervening operations that do -/// not use the POD value at all. That keeps the rewrite local and avoids reasoning about other -/// whole-POD uses or record accesses that may observe mutation ordering. -static WritePodOp findNearestForwardableWriteInBlock(ReadPodOp readOp) { - Value podRef = readOp.getPodRef(); - StringAttr recordName = readOp.getRecordNameAttr(); - - for (Operation *op = readOp->getPrevNode(); op; op = op->getPrevNode()) { - if (!hasValueUse(*op, podRef)) { - continue; - } - - auto writeOp = dyn_cast(op); - return writeOp && isSamePodRecord(writeOp, podRef, recordName) ? writeOp : nullptr; - } - return nullptr; -} - /// Return whether the read is preceded by a write to the same pod record within its block. static bool hasEarlierWriteInBlock(ReadPodOp readOp) { Value podRef = readOp.getPodRef(); diff --git a/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk b/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk new file mode 100644 index 000000000..54873c893 --- /dev/null +++ b/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk @@ -0,0 +1,57 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +#inner = affine_map<()[s0] -> (s0 + 1)> +!Elem = !pod.type<[@vals: !array.type<#inner x index>]> +!ElemArray = !array.type<#outer x !Elem> +module attributes {llzk.lang} { + function.def @uninitialized(%n: index) -> !ElemArray { + // The affine map initializer here is for the `#outer` map in `!ElemArray`, not the `#inner` map in `!Elem`. + // The `#inner` map can only be initialized via an `array.new` that is then written to the pod which is not + // present in this example. Thus, when `pod-to-scalar` combines the arrays to create a single, 2-D array, + // the second dimension size must be `?` because no affine instantiation exists to provide as an affine + // instantiation. However, the return value itself is unchanged so a `unifiable_cast` is used to convert. + %arr = array.new{()[%n]} : !ElemArray + function.return %arr : !ElemArray + } +} +// CHECK: #[[$M1:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK: #[[$M2:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @uninitialized(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> !array.type<#[[$M1]],#[[$M2]] x index> { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <#[[$M1]],? x index> +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_1]] : (!array.type<#[[$M1]],? x index>) -> !array.type<#[[$M1]],#[[$M2]] x index> +// CHECK-NEXT: function.return %[[VAL_2]] : !array.type<#[[$M1]],#[[$M2]] x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +#inner = affine_map<()[s0] -> (s0 + 1)> +!Leaf = !array.type<#inner x index> +!Elem = !pod.type<[@vals: !Leaf]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @materialized(%n: index) -> !ElemArray { + %vals0 = array.new{()[%n]} : !Leaf + %vals1 = array.new{()[%n]} : !Leaf + %lhs = pod.new()[%n] : !Elem + pod.write %lhs[@vals] = %vals0 : !Elem, !Leaf + %rhs = pod.new()[%n] : !Elem + pod.write %rhs[@vals] = %vals1 : !Elem, !Leaf + %arr = array.new %lhs, %rhs : !ElemArray + function.return %arr : !ElemArray + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @materialized(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> !array.type<2,#[[$M]] x index> { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <#[[$M]] x index> +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <#[[$M]] x index> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = array.new{(){{\[}}%[[VAL_0]]]} : <2,#[[$M]] x index> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.insert %[[VAL_3]]{{\[}}%[[VAL_4]]] = %[[VAL_1]] : <2,#[[$M]] x index>, <#[[$M]] x index> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.insert %[[VAL_3]]{{\[}}%[[VAL_5]]] = %[[VAL_2]] : <2,#[[$M]] x index>, <#[[$M]] x index> +// CHECK-NEXT: function.return %[[VAL_3]] : !array.type<2,#[[$M]] x index> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PodToScalar/array_new_affine_leaf_array_conflict.llzk b/test/Transforms/PodToScalar/array_new_affine_leaf_array_conflict.llzk new file mode 100644 index 000000000..da5fa9a2c --- /dev/null +++ b/test/Transforms/PodToScalar/array_new_affine_leaf_array_conflict.llzk @@ -0,0 +1,20 @@ +// RUN: llzk-opt %s -split-input-file -llzk-pod-to-scalar -verify-diagnostics + +#inner = affine_map<()[s0] -> (s0 + 1)> +!InnerArr = !array.type<#inner x index> +!Elem = !pod.type<[@vals: !InnerArr]> +!ElemArray = !array.type<2 x !Elem> +module attributes {llzk.lang} { + function.def @conflict(%m: index, %n: index) -> !ElemArray { + %vals0 = array.new{()[%m]} : !InnerArr + %vals1 = array.new{()[%n]} : !InnerArr + %lhs = pod.new()[%m] : !Elem + pod.write %lhs[@vals] = %vals0 : !Elem, !InnerArr + %rhs = pod.new()[%n] : !Elem + pod.write %rhs[@vals] = %vals1 : !Elem, !InnerArr + // expected-error@+2 {{'array.new' op with POD elements having conflicting affine map instantiations cannot be promoted to higher dimensional array}} + // expected-error@+1 {{failed to legalize operation 'array.new'}} + %arr = array.new %lhs, %rhs : !ElemArray + function.return %arr : !ElemArray + } +} From 04c59afc01bd375c58dd85033040ca573aadb290 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 15:04:28 -0500 Subject: [PATCH 21/36] fix: Preserve original rank for array.len --- .../POD/Transforms/PodToScalarPass.cpp | 29 +++++++++++++++---- test/Transforms/PodToScalar/array_length.llzk | 18 ++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index e0214aae7..5431d373f 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -1635,7 +1635,28 @@ class SplitPodArrayInEmitContainmentOp : public OpConversionPattern(arrRef.getType()); + assert(arrTy && "converted array-of-POD operand must stay an array"); + if (arrTy.getDimensionSizes().size() == originalRank) { + return arrRef; + } + } + + return materializeArrayLengthCarrier(op.getArrRef(), op.getArrRefType(), op.getLoc(), rewriter); +} + +/// Replace `array.length` on an array-of-POD with an equivalent rank-preserving array value. class SplitPodArrayLengthOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1648,11 +1669,7 @@ class SplitPodArrayLengthOp : public OpConversionPattern { if (legal(op)) { return failure(); } - Value arrRef = adaptor.getArrRef().empty() - ? materializeArrayLengthCarrier( - op.getArrRef(), op.getArrRefType(), op.getLoc(), rewriter - ) - : adaptor.getArrRef().front(); + Value arrRef = selectArrayLengthShapeSource(op, adaptor.getArrRef(), rewriter); rewriter.replaceOpWithNewOp( op, arrRef, getSingleConvertedValue(adaptor.getDim()) ); diff --git a/test/Transforms/PodToScalar/array_length.llzk b/test/Transforms/PodToScalar/array_length.llzk index 8e44ec94a..2ea7734d1 100644 --- a/test/Transforms/PodToScalar/array_length.llzk +++ b/test/Transforms/PodToScalar/array_length.llzk @@ -15,6 +15,24 @@ module attributes {llzk.lang} { // CHECK-NEXT: } // ----- +// Tests: preserve the `array.len` semantics so a dynamic dimension selection +// cannot observe dimensions that did not exist before pod flattening. +!Elem = !pod.type<[@vals: !array.type<3 x index>]> +module attributes {llzk.lang} { + function.def @len_leaf_array_dynamic_dim(%arr: !array.type<2 x !Elem>, %dim: index) -> index { + %len = array.len %arr, %dim : !array.type<2 x !Elem> + function.return %len : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_leaf_array_dynamic_dim(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<2,3 x index>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[SHAPE:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[SHAPE]], %[[DIM]] : <2 x index> +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + module attributes {llzk.lang} { function.def @len_empty_leaf_static_create(%dim: index) -> index { %arr = array.new : !array.type<4 x !pod.type<[]>> From ca87a4e2b6df74914083d49ba8d13852cfb83633 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 15:22:58 -0500 Subject: [PATCH 22/36] fix: Support dynamic nested POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 86 ++++++++++++++++++- .../nonstatic_nested_pod_array.llzk | 55 ++++++++++++ 2 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 test/Transforms/PodToScalar/nonstatic_nested_pod_array.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 5431d373f..a869db80e 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -702,6 +702,45 @@ genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayAttr index, Va genArrayWrite(bldr, loc, arrayRef, indices, value); } +/// Collect split leaf arrays that are already available for an aggregate array-of-POD value. +/// +/// This peels compatibility casts and forwards through a dominating same-record `pod.write` so +/// nested POD scalarization can reuse the split-array representation already produced elsewhere in +/// the pass instead of re-materializing dynamic arrays element-by-element. +static bool tryCollectDirectSplitPodArrayLeafValues( + Value arrayValue, ArrayType arrTy, ArrayRef splitTypes, SmallVectorImpl &leafArrays +) { + while (auto cast = arrayValue.getDefiningOp()) { + arrayValue = cast.getInput(); + } + + if (auto cast = arrayValue.getDefiningOp()) { + if (cast->getNumResults() != 1 || cast.getResult(0).getType() != arrTy || + cast->getNumOperands() != splitTypes.size()) { + return false; + } + + leafArrays.reserve(splitTypes.size()); + for (auto [operand, splitType] : llvm::zip_equal(cast.getOperands(), splitTypes)) { + if (operand.getType() != splitType) { + return false; + } + leafArrays.push_back(operand); + } + return true; + } + + if (ReadPodOp readOp = arrayValue.getDefiningOp()) { + if (WritePodOp writeOp = findNearestForwardableWriteInBlock(readOp)) { + return tryCollectDirectSplitPodArrayLeafValues( + writeOp.getValue(), arrTy, splitTypes, leafArrays + ); + } + } + + return false; +} + /// Read one flattened POD leaf, including leaves that live inside an array-of-POD record. static Value genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef recordChain) { @@ -716,8 +755,33 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef(getFlattenedTypeAlongPath(valueType, recordChain)); + + if (!arrTy.hasStaticShape()) { + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector leafArrays; + if (tryCollectDirectSplitPodArrayLeafValues(value, arrTy, splitTypes, leafArrays)) { + auto it = llvm::find(splitIds, RecordChain(recordChain)); + assert(it != splitIds.end() && "record path must name a flattened POD array leaf"); + return leafArrays[std::distance(splitIds.begin(), it)]; + } + + if (ReadPodOp readOp = value.getDefiningOp()) { + if (llvm::isa(readOp.getPodRef().getDefiningOp()) && + !findNearestForwardableWriteInBlock(readOp)) { + return createWritableArrayValue(bldr, loc, splitArrTy); + } + } + + llvm_unreachable( + "non-static nested array-of-POD scalarization requires split-array backing or an " + "uninitialized pod field" + ); + } + auto subIndices = arrTy.getSubelementIndices(); assert(subIndices && "static-shape arrays must provide subelement indices"); @@ -757,7 +821,25 @@ static Value rebuildFlattenedPodRecord( } if (ArrayType arrTy = splittablePodArray(recordType)) { - assert(arrTy.hasStaticShape() && "nested array-of-POD scalarization requires a static shape"); + if (!arrTy.hasStaticShape()) { + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + SmallVector leafArrays; + leafArrays.reserve(splitIds.size()); + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + SmallVector fullChain(recordChain.begin(), recordChain.end()); + llvm::append_range(fullChain, id.nameList); + auto it = leafValues.find(RecordChain(fullChain)); + assert(it != leafValues.end() && "missing flattened POD array leaf value"); + leafArrays.push_back(castValueToTypeIfNeeded(bldr, loc, it->second, splitType)); + } + + return bldr.create(loc, TypeRange {arrTy}, leafArrays) + .getResult(0); + } + auto elemPodTy = llvm::cast(arrTy.getElementType()); auto subIndices = arrTy.getSubelementIndices(); assert(subIndices && "static-shape arrays must provide subelement indices"); diff --git a/test/Transforms/PodToScalar/nonstatic_nested_pod_array.llzk b/test/Transforms/PodToScalar/nonstatic_nested_pod_array.llzk new file mode 100644 index 000000000..ddcf94efe --- /dev/null +++ b/test/Transforms/PodToScalar/nonstatic_nested_pod_array.llzk @@ -0,0 +1,55 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + function.def @return_direct(%n: index) -> !Outer { + %p = pod.new()[%n] : !Outer + function.return %p : !Outer + } +} +// CHECK: #[[$MAP0:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @return_direct(%[[N0:[0-9a-zA-Z_\.]+]]: index) -> !array.type<#[[$MAP0]] x index> { +// CHECK-NEXT: %[[ARR0:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#[[$MAP0]] x index> +// CHECK-NEXT: function.return %[[ARR0]] : !array.type<#[[$MAP0]] x index> +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !Outer + + function.def @compute(%n: index) -> !struct.type<@S> { + %p = pod.new()[%n] : !Outer + %self = struct.new : !struct.type<@S> + struct.writem %self[@m] = %p : !struct.type<@S>, !Outer + function.return %self : !struct.type<@S> + } + + function.def @constrain(%self: !struct.type<@S>, %n: index) { + function.return + } + } +} +// CHECK: #[[$MAP1:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-NEXT: struct.member @m_items_x : !array.type<#[[$MAP1]] x index> +// CHECK-NEXT: function.def @compute(%[[N1:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[SELF:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-NEXT: %[[ARR1:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#[[$MAP1]] x index> +// CHECK-NEXT: struct.writem %[[SELF]][@m_items_x] = %[[ARR1]] : <@S>, !array.type<#[[$MAP1]] x index> +// CHECK-NEXT: function.return %[[SELF]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[SELFARG:[0-9a-zA-Z_\.]+]]: !struct.type<@S>, %[[N2:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } From e8bceaa7f1aa345bb42a087cfb9b294fa090faba Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 20:56:33 -0500 Subject: [PATCH 23/36] fix: Handle uninitialized deferred POD-array reads --- .../POD/Transforms/PodToScalarPass.cpp | 103 ++++++++++++------ .../array_read_from_unwritten_pod_field.llzk | 22 ++++ 2 files changed, 94 insertions(+), 31 deletions(-) create mode 100644 test/Transforms/PodToScalar/array_read_from_unwritten_pod_field.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index a869db80e..e57d9d63e 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -877,6 +877,7 @@ static Value rebuildFlattenedPodRecord( using VirtualPodLeafMap = DenseMap; using VirtualPodValueMap = DenseMap; +using DeferredPodArrayLeafMap = DenseMap>; /// Return the flattened leaf values for `podValue` when it is tracked as a virtual POD. static const VirtualPodLeafMap * @@ -926,6 +927,40 @@ static bool canResolveVirtualPodRead(ReadPodOp op, const VirtualPodValueMap &vir return llvm::isa(recType) || !splittablePodArray(recType); } +/// Return whether the read is preceded by a write to the same pod record within its block. +static bool hasEarlierWriteInBlock(ReadPodOp readOp) { + Value podRef = readOp.getPodRef(); + StringAttr recordName = readOp.getRecordNameAttr(); + + for (Operation &op : *readOp->getBlock()) { + if (&op == readOp.getOperation()) { + return false; + } + + if (auto writeOp = dyn_cast(&op)) { + if (isSamePodRecord(writeOp, podRef, recordName)) { + return true; + } + } else if (hasNestedWriteToRecord(op, podRef, recordName)) { + return true; + } + } + return false; +} + +/// Return `true` iff `readOp` names a fresh pod record that has not been initialized or written. +static bool isFreshUnwrittenPodRead(ReadPodOp readOp) { + NewPodOp newPod = readOp.getPodRef().getDefiningOp(); + if (!newPod) { + return false; + } + auto isReadOpRecordName = [&readOp](Attribute attr) { + return attr == readOp.getRecordNameAttr(); + }; + return llvm::none_of(newPod.getInitializedRecords(), isReadOpRecordName) && + !hasEarlierWriteInBlock(readOp); +} + /// Return `true` iff step 2 should defer splitting this array read until POD-aware rewriting. static bool shouldDeferPodArrayReadToStep3(ReadArrayOp op) { return splittablePodArray(op.getArrRefType()) && @@ -2393,10 +2428,15 @@ static bool tryCollectReadPodSplitPodArrayLeafValues( /// rebuilds the element POD virtually instead of materializing the whole aggregate array first. class ResolvePodReadBackedArrayReadOp : public OpConversionPattern { VirtualPodValueMap &virtualPods; + DeferredPodArrayLeafMap &deferredPodArrays; public: - ResolvePodReadBackedArrayReadOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) - : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + ResolvePodReadBackedArrayReadOp( + MLIRContext *ctx, VirtualPodValueMap &virtualPodMap, + DeferredPodArrayLeafMap &deferredPodArrayMap + ) + : OpConversionPattern(ctx), virtualPods(virtualPodMap), + deferredPodArrays(deferredPodArrayMap) {} static bool canResolve(ReadArrayOp op, const VirtualPodValueMap &virtualPods) { if (!shouldDeferPodArrayReadToStep3(op)) { @@ -2404,15 +2444,16 @@ class ResolvePodReadBackedArrayReadOp : public OpConversionPattern } ArrayType arrTy = op.getArrRefType(); + auto fieldRead = llvm::cast(op.getArrRef().getDefiningOp()); SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); SmallVector ignoredLeafArrays; return tryCollectReadPodSplitPodArrayLeafValues( - llvm::cast(op.getArrRef().getDefiningOp()), arrTy, splitIds, splitTypes, - virtualPods, ignoredLeafArrays - ); + fieldRead, arrTy, splitIds, splitTypes, virtualPods, ignoredLeafArrays + ) || + isFreshUnwrittenPodRead(fieldRead); } LogicalResult matchAndRewrite( @@ -2433,7 +2474,30 @@ class ResolvePodReadBackedArrayReadOp : public OpConversionPattern if (!tryCollectReadPodSplitPodArrayLeafValues( fieldRead, arrTy, splitIds, splitTypes, virtualPods, splitLeafArrays )) { - return failure(); + if (!isFreshUnwrittenPodRead(fieldRead)) { + return failure(); + } + + // Reuse one synthetic split-array backing per deferred field read so repeated element reads + // from the same aggregate value see the same unwritten leaf storage. + auto [it, inserted] = deferredPodArrays.try_emplace(fieldRead.getResult()); + splitLeafArrays.assign(it->second.begin(), it->second.end()); + if (inserted) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(fieldRead); + splitLeafArrays.reserve(splitTypes.size()); + for (Type splitType : splitTypes) { + splitLeafArrays.push_back( + createWritableArrayValue(rewriter, op.getLoc(), llvm::cast(splitType)) + ); + } + it->second = splitLeafArrays; + } else { + assert( + splitLeafArrays.size() == splitTypes.size() && + "cached split POD arrays must match the rewritten read arity" + ); + } } SmallVector indices(adaptor.getIndices().begin(), adaptor.getIndices().end()); @@ -2506,6 +2570,7 @@ static LogicalResult step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) { MLIRContext *ctx = modOp.getContext(); VirtualPodValueMap virtualPods; + DeferredPodArrayLeafMap deferredPodArrays; RewritePatternSet patterns(ctx); patterns.add(ctx); @@ -2514,7 +2579,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM patterns.add( ctx, symTables, memberRepMap, virtualPods ); - patterns.add(ctx, virtualPods); + patterns.add(ctx, virtualPods, deferredPodArrays); patterns.add(ctx, virtualPods); ConversionTarget target(*ctx); @@ -2551,30 +2616,6 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM return success(); } -/// Return whether the read is preceded by a write to the same pod record within its block. -static bool hasEarlierWriteInBlock(ReadPodOp readOp) { - Value podRef = readOp.getPodRef(); - StringAttr recordName = readOp.getRecordNameAttr(); - - for (Operation &op : *readOp->getBlock()) { - if (&op == readOp.getOperation()) { - return false; - } - - if (auto writeOp = dyn_cast(&op)) { - if (isSamePodRecord(writeOp, podRef, recordName)) { - return true; - } - continue; - } - - if (hasNestedWriteToRecord(op, podRef, recordName)) { - return true; - } - } - return false; -} - /// Return whether `value` is defined within `ancestor` or one of its nested regions. /// /// Values defined inside a control-flow operation cannot be hoisted across that operation without diff --git a/test/Transforms/PodToScalar/array_read_from_unwritten_pod_field.llzk b/test/Transforms/PodToScalar/array_read_from_unwritten_pod_field.llzk new file mode 100644 index 000000000..33be2b4af --- /dev/null +++ b/test/Transforms/PodToScalar/array_read_from_unwritten_pod_field.llzk @@ -0,0 +1,22 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + function.def @read_unwritten_dynamic(%n: index, %i: index) -> index { + %p = pod.new()[%n] : !Outer + %items = pod.read %p[@items] : !Outer, !array.type<#outer x !Item> + %item = array.read %items[%i] : !array.type<#outer x !Item>, !Item + %x = pod.read %item[@x] : !Item, index + function.return %x : index + } +} +// CHECK: #[[$ATTR_0:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @read_unwritten_dynamic(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#[[$ATTR_0]] x index> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_2]]{{\[}}%[[VAL_1]]] : <#[[$ATTR_0]] x index>, index +// CHECK-NEXT: function.return %[[VAL_3]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } From 0b804d4bffbd8b6deb3d83c86cc7acbb1154d097 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 21:07:50 -0500 Subject: [PATCH 24/36] fix: Apply uninitialized fallback to static POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 89 ++++++++++--------- .../static_unwritten_nested_pod_array.llzk | 50 +++++++++++ 2 files changed, 98 insertions(+), 41 deletions(-) create mode 100644 test/Transforms/PodToScalar/static_unwritten_nested_pod_array.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index e57d9d63e..0cc1eb5cc 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -741,6 +741,50 @@ static bool tryCollectDirectSplitPodArrayLeafValues( return false; } +/// Return whether the read is preceded by a write to the same pod record within its block. +static bool hasEarlierWriteInBlock(ReadPodOp readOp) { + Value podRef = readOp.getPodRef(); + StringAttr recordName = readOp.getRecordNameAttr(); + + for (Operation &op : *readOp->getBlock()) { + if (&op == readOp.getOperation()) { + return false; + } + + if (auto writeOp = dyn_cast(&op)) { + if (isSamePodRecord(writeOp, podRef, recordName)) { + return true; + } + } else if (hasNestedWriteToRecord(op, podRef, recordName)) { + return true; + } + } + return false; +} + +/// Return `true` iff `readOp` names a fresh pod record that has not been initialized or written. +static bool isFreshUnwrittenPodRead(ReadPodOp readOp) { + NewPodOp newPod = readOp.getPodRef().getDefiningOp(); + if (!newPod) { + return false; + } + auto isReadOpRecordName = [&readOp](Attribute attr) { + return attr == readOp.getRecordNameAttr(); + }; + return llvm::none_of(newPod.getInitializedRecords(), isReadOpRecordName) && + !hasEarlierWriteInBlock(readOp); +} + +/// Return `true` iff `value` is an unwritten array-of-POD field read from a fresh `pod.new`. +static bool isFreshUnwrittenPodArrayRead(Value value) { + while (auto cast = value.getDefiningOp()) { + value = cast.getInput(); + } + + ReadPodOp readOp = value.getDefiningOp(); + return readOp && splittablePodArray(readOp.getType()) && isFreshUnwrittenPodRead(readOp); +} + /// Read one flattened POD leaf, including leaves that live inside an array-of-POD record. static Value genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef recordChain) { @@ -757,6 +801,10 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef(getFlattenedTypeAlongPath(valueType, recordChain)); + if (isFreshUnwrittenPodArrayRead(value)) { + return createWritableArrayValue(bldr, loc, splitArrTy); + } + if (!arrTy.hasStaticShape()) { SmallVector splitIds; SmallVector splitTypes; @@ -769,13 +817,6 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef()) { - if (llvm::isa(readOp.getPodRef().getDefiningOp()) && - !findNearestForwardableWriteInBlock(readOp)) { - return createWritableArrayValue(bldr, loc, splitArrTy); - } - } - llvm_unreachable( "non-static nested array-of-POD scalarization requires split-array backing or an " "uninitialized pod field" @@ -927,40 +968,6 @@ static bool canResolveVirtualPodRead(ReadPodOp op, const VirtualPodValueMap &vir return llvm::isa(recType) || !splittablePodArray(recType); } -/// Return whether the read is preceded by a write to the same pod record within its block. -static bool hasEarlierWriteInBlock(ReadPodOp readOp) { - Value podRef = readOp.getPodRef(); - StringAttr recordName = readOp.getRecordNameAttr(); - - for (Operation &op : *readOp->getBlock()) { - if (&op == readOp.getOperation()) { - return false; - } - - if (auto writeOp = dyn_cast(&op)) { - if (isSamePodRecord(writeOp, podRef, recordName)) { - return true; - } - } else if (hasNestedWriteToRecord(op, podRef, recordName)) { - return true; - } - } - return false; -} - -/// Return `true` iff `readOp` names a fresh pod record that has not been initialized or written. -static bool isFreshUnwrittenPodRead(ReadPodOp readOp) { - NewPodOp newPod = readOp.getPodRef().getDefiningOp(); - if (!newPod) { - return false; - } - auto isReadOpRecordName = [&readOp](Attribute attr) { - return attr == readOp.getRecordNameAttr(); - }; - return llvm::none_of(newPod.getInitializedRecords(), isReadOpRecordName) && - !hasEarlierWriteInBlock(readOp); -} - /// Return `true` iff step 2 should defer splitting this array read until POD-aware rewriting. static bool shouldDeferPodArrayReadToStep3(ReadArrayOp op) { return splittablePodArray(op.getArrRefType()) && diff --git a/test/Transforms/PodToScalar/static_unwritten_nested_pod_array.llzk b/test/Transforms/PodToScalar/static_unwritten_nested_pod_array.llzk new file mode 100644 index 000000000..02a41875d --- /dev/null +++ b/test/Transforms/PodToScalar/static_unwritten_nested_pod_array.llzk @@ -0,0 +1,50 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<2 x !Item>]> +module attributes {llzk.lang} { + function.def @return_unwritten_static() -> !Outer { + %p = pod.new : !Outer + function.return %p : !Outer + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @return_unwritten_static() -> !array.type<2 x index> { +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: function.return %[[ARR]] : !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<2 x !Item>]> +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !Outer + + function.def @compute() -> !struct.type<@S> { + %p = pod.new : !Outer + %self = struct.new : !struct.type<@S> + struct.writem %self[@m] = %p : !struct.type<@S>, !Outer + function.return %self : !struct.type<@S> + } + + function.def @constrain(%self: !struct.type<@S>) { + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-NEXT: struct.member @m_items_x : !array.type<2 x index> +// CHECK-NEXT: function.def @compute() -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[SELF:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: struct.writem %[[SELF]][@m_items_x] = %[[ARR]] : <@S>, !array.type<2 x index> +// CHECK-NEXT: function.return %[[SELF]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[SELFARG:[0-9a-zA-Z_\.]+]]: !struct.type<@S>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } From 3d0e44365c49652b4c6e67b96c76a9b2e2004e51 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 21:35:06 -0500 Subject: [PATCH 25/36] fix: Avoid aborting on dynamic nested POD operands --- .../POD/Transforms/PodToScalarPass.cpp | 315 ++++++++++++++---- ...write_direct_pod_dynamic_nested_array.llzk | 26 ++ 2 files changed, 285 insertions(+), 56 deletions(-) create mode 100644 test/Transforms/PodToScalar/array_write_direct_pod_dynamic_nested_array.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 0cc1eb5cc..eb3a49caf 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -702,6 +702,14 @@ genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayAttr index, Va genArrayWrite(bldr, loc, arrayRef, indices, value); } +/// Strip compatibility casts introduced while threading POD-derived array values through rewrites. +static Value peelUnifiableCasts(Value value) { + while (auto cast = value.getDefiningOp()) { + value = cast.getInput(); + } + return value; +} + /// Collect split leaf arrays that are already available for an aggregate array-of-POD value. /// /// This peels compatibility casts and forwards through a dominating same-record `pod.write` so @@ -710,9 +718,7 @@ genArrayWrite(OpBuilder &bldr, Location loc, Value arrayRef, ArrayAttr index, Va static bool tryCollectDirectSplitPodArrayLeafValues( Value arrayValue, ArrayType arrTy, ArrayRef splitTypes, SmallVectorImpl &leafArrays ) { - while (auto cast = arrayValue.getDefiningOp()) { - arrayValue = cast.getInput(); - } + arrayValue = peelUnifiableCasts(arrayValue); if (auto cast = arrayValue.getDefiningOp()) { if (cast->getNumResults() != 1 || cast.getResult(0).getType() != arrTy || @@ -777,10 +783,7 @@ static bool isFreshUnwrittenPodRead(ReadPodOp readOp) { /// Return `true` iff `value` is an unwritten array-of-POD field read from a fresh `pod.new`. static bool isFreshUnwrittenPodArrayRead(Value value) { - while (auto cast = value.getDefiningOp()) { - value = cast.getInput(); - } - + value = peelUnifiableCasts(value); ReadPodOp readOp = value.getDefiningOp(); return readOp && splittablePodArray(readOp.getType()) && isFreshUnwrittenPodRead(readOp); } @@ -817,6 +820,15 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef()) { + auto splitLeafReads = + bldr.create(loc, TypeRange(splitTypes), strippedValue); + auto it = llvm::find(splitIds, RecordChain(recordChain)); + assert(it != splitIds.end() && "record path must name a flattened POD array leaf"); + return splitLeafReads.getResult(std::distance(splitIds.begin(), it)); + } + llvm_unreachable( "non-static nested array-of-POD scalarization requires split-array backing or an " "uninitialized pod field" @@ -944,6 +956,24 @@ static SmallVector orderedVirtualPodLeafValues( return orderedValues; } +/// Create a POD-typed placeholder for virtual leaf storage tracked in `leafValues`. +/// +/// PODs that embed affine-map-parameterized arrays cannot always be represented by a bare +/// `pod.new` at this stage because there may be no op-local instantiation operands available. +/// Use an unrealized cast from the ordered leaf values for those cases; later rewrites consult +/// `virtualPods` directly, and only concrete `pod.new` placeholders require materialization. +static Value createVirtualPodPlaceholder( + OpBuilder &bldr, Location loc, PodType podTy, const VirtualPodLeafMap &leafValues +) { + if (!hasAffineMapAttr(podTy)) { + return bldr.create(loc, podTy); + } + + SmallVector orderedValues = orderedVirtualPodLeafValues(podTy, loc, bldr, leafValues); + return bldr.create(loc, TypeRange {podTy}, orderedValues) + .getResult(0); +} + /// Materialize the tracked contents of a virtual POD into concrete `pod.write` operations. inline static void materializeVirtualPod(OpBuilder &bldr, NewPodOp pod, const VirtualPodLeafMap &leafValues) { @@ -1971,6 +2001,7 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM ConversionTarget target(*ctx); baseTargetSetup(target); + target.addLegalOp(); target.addDynamicallyLegalOp(SplitPodArrayNonDetOp::legal); target.addDynamicallyLegalOp(SplitPodArrayCreateArrayOp::legal); target.addDynamicallyLegalOp(SplitPodArrayReadArrayOp::legal); @@ -2162,19 +2193,21 @@ class SplitPodInFuncDefOp : public OpConversionPattern { Value oldV = entryBlock.getArgument(i); if (PodType pt = splittablePod(oldV.getType())) { Location loc = oldV.getLoc(); - auto newPod = rewriter.create(loc, pt); - rewriter.replaceAllUsesWith(oldV, newPod); - // Remove the argument from the block - entryBlock.eraseArgument(i); - - DenseMap leafValues; + VirtualPodLeafMap leafValues; SmallVector recordChain; + unsigned nextArgIdx = i + 1; forEachPodLeaf(pt, recordChain, [&](const RecordChain &id, Type leafType) { - BlockArgument newArg = entryBlock.insertArgument(i, leafType, loc); + BlockArgument newArg = entryBlock.insertArgument(nextArgIdx, leafType, loc); leafValues[id] = newArg; - ++i; + ++nextArgIdx; }); - virtualPods[newPod] = std::move(leafValues); + + Value virtualPod = createVirtualPodPlaceholder(rewriter, loc, pt, leafValues); + rewriter.replaceAllUsesWith(oldV, virtualPod); + entryBlock.eraseArgument(i); + + i += leafValues.size(); + virtualPods[virtualPod] = std::move(leafValues); } else { ++i; } @@ -2240,15 +2273,15 @@ static CallOp newCallOpWithSplitResults( for (Value oldVal : oldResults) { if (PodType pt = splittablePod(oldVal.getType())) { Location loc = oldVal.getLoc(); - DenseMap leafValues; + VirtualPodLeafMap leafValues; SmallVector recordChain; forEachPodLeaf(pt, recordChain, [&leafValues, &newResults](const RecordChain &id, Type) { leafValues[id] = *newResults; ++newResults; }); - NewPodOp newPod = rewriter.create(loc, pt); - virtualPods[newPod] = std::move(leafValues); - rewriter.replaceAllUsesWith(oldVal, newPod); + Value virtualPod = createVirtualPodPlaceholder(rewriter, loc, pt, leafValues); + virtualPods[virtualPod] = std::move(leafValues); + rewriter.replaceAllUsesWith(oldVal, virtualPod); } else { rewriter.replaceAllUsesWith(oldVal, *newResults); newResults++; @@ -2366,9 +2399,10 @@ class SplitPodInMemberReadOp : public OpConversionPattern { ); } - NewPodOp pod = rewriter.create(op.getLoc(), llvm::cast(op.getType())); - virtualPods[pod] = std::move(leafValues); - rewriter.replaceOp(op, pod); + PodType podTy = llvm::cast(op.getType()); + Value virtualPod = createVirtualPodPlaceholder(rewriter, op.getLoc(), podTy, leafValues); + virtualPods[virtualPod] = std::move(leafValues); + rewriter.replaceOp(op, virtualPod); } }; @@ -2428,6 +2462,88 @@ static bool tryCollectReadPodSplitPodArrayLeafValues( return false; } +/// Materialize or recover split leaf arrays for a dynamic array-of-POD produced by `pod.read`. +static bool resolveReadPodSplitPodArrayLeafValues( + ReadPodOp readOp, ArrayType arrTy, ArrayRef splitIds, ArrayRef splitTypes, + const VirtualPodValueMap &virtualPods, DeferredPodArrayLeafMap &deferredPodArrays, Location loc, + OpBuilder &bldr, SmallVectorImpl &leafArrays +) { + if (tryCollectReadPodSplitPodArrayLeafValues( + readOp, arrTy, splitIds, splitTypes, virtualPods, leafArrays + )) { + return true; + } + + if (!isFreshUnwrittenPodRead(readOp)) { + return false; + } + + // Reuse one synthetic split-array backing per deferred field read so repeated users of the same + // aggregate value continue to observe the same unwritten leaf storage. + auto [it, inserted] = deferredPodArrays.try_emplace(readOp.getResult()); + leafArrays.assign(it->second.begin(), it->second.end()); + if (inserted) { + OpBuilder::InsertionGuard guard(bldr); + bldr.setInsertionPointAfter(readOp); + leafArrays.reserve(splitTypes.size()); + for (Type splitType : splitTypes) { + leafArrays.push_back(createWritableArrayValue(bldr, loc, llvm::cast(splitType))); + } + it->second.assign(leafArrays.begin(), leafArrays.end()); + } else { + assert( + leafArrays.size() == splitTypes.size() && + "cached split POD arrays must match the rewritten read arity" + ); + } + + return true; +} + +/// Erase a resolved deferred field-read chain once both the read and its placeholder pod vanish. +static void eraseDeadDeferredFieldReadChain(ReadPodOp readOp, PatternRewriter &rewriter) { + if (!readOp.getResult().use_empty()) { + return; + } + + Value podRef = readOp.getPodRef(); + rewriter.eraseOp(readOp); + if (podRef.use_empty()) { + if (auto cast = podRef.getDefiningOp()) { + if (cast->getNumResults() == 1 && cast.getResult(0) == podRef) { + rewriter.eraseOp(cast); + } + } + } +} + +/// Return `true` iff `op` is a deferred split placeholder for one array-of-POD aggregate value. +static bool getDeferredSplitPodArrayCastInfo( + UnrealizedConversionCastOp op, ArrayType &arrTy, SmallVector &splitIds, + SmallVectorImpl &splitTypes +) { + if (op->getNumOperands() != 1) { + return false; + } + + arrTy = splittablePodArray(op.getOperand(0).getType()); + if (!arrTy) { + return false; + } + + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (op->getNumResults() != splitTypes.size()) { + return false; + } + + for (auto [result, splitType] : llvm::zip_equal(op.getResults(), splitTypes)) { + if (result.getType() != splitType) { + return false; + } + } + return true; +} + /// Resolve deferred `array.read` from `pod.read`-produced array-of-POD values. /// /// When step 2 defers a read because the array-of-POD came from a POD record, this pattern @@ -2478,44 +2594,91 @@ class ResolvePodReadBackedArrayReadOp : public OpConversionPattern splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); SmallVector splitLeafArrays; - if (!tryCollectReadPodSplitPodArrayLeafValues( - fieldRead, arrTy, splitIds, splitTypes, virtualPods, splitLeafArrays + if (!resolveReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, deferredPodArrays, op.getLoc(), + rewriter, splitLeafArrays )) { - if (!isFreshUnwrittenPodRead(fieldRead)) { - return failure(); - } - - // Reuse one synthetic split-array backing per deferred field read so repeated element reads - // from the same aggregate value see the same unwritten leaf storage. - auto [it, inserted] = deferredPodArrays.try_emplace(fieldRead.getResult()); - splitLeafArrays.assign(it->second.begin(), it->second.end()); - if (inserted) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(fieldRead); - splitLeafArrays.reserve(splitTypes.size()); - for (Type splitType : splitTypes) { - splitLeafArrays.push_back( - createWritableArrayValue(rewriter, op.getLoc(), llvm::cast(splitType)) - ); - } - it->second = splitLeafArrays; - } else { - assert( - splitLeafArrays.size() == splitTypes.size() && - "cached split POD arrays must match the rewritten read arity" - ); - } + return failure(); } SmallVector indices(adaptor.getIndices().begin(), adaptor.getIndices().end()); - DenseMap leafValues; + VirtualPodLeafMap leafValues; for (auto [id, leafArray] : llvm::zip_equal(splitIds, splitLeafArrays)) { leafValues[id] = genArrayRead(rewriter, op.getLoc(), leafArray, indices); } - NewPodOp pod = rewriter.create(op.getLoc(), podTy); - virtualPods[pod] = std::move(leafValues); - rewriter.replaceOp(op, pod); + Value virtualPod = createVirtualPodPlaceholder(rewriter, op.getLoc(), podTy, leafValues); + virtualPods[virtualPod] = std::move(leafValues); + rewriter.replaceOp(op, virtualPod); + eraseDeadDeferredFieldReadChain(fieldRead, rewriter); + return success(); + } +}; + +/// Resolve deferred split-array placeholders created while flattening direct POD operands. +/// +/// Step 2 may need one specific split leaf array from a dynamic array-of-POD field before step 3 +/// has converted the surrounding POD value into virtual leaf storage. In that case +/// `genReadAlongPath` leaves behind a `builtin.unrealized_conversion_cast` from the aggregate field +/// read to all split leaf arrays, and this pattern resolves that placeholder once the backing leaf +/// arrays become available. +class ResolveDeferredSplitPodArrayCastOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + DeferredPodArrayLeafMap &deferredPodArrays; + +public: + ResolveDeferredSplitPodArrayCastOp( + MLIRContext *ctx, VirtualPodValueMap &virtualPodMap, + DeferredPodArrayLeafMap &deferredPodArrayMap + ) + : OpConversionPattern(ctx), virtualPods(virtualPodMap), + deferredPodArrays(deferredPodArrayMap) {} + + static bool canResolve(UnrealizedConversionCastOp op, const VirtualPodValueMap &virtualPods) { + ArrayType arrTy; + SmallVector splitIds; + SmallVector splitTypes; + if (!getDeferredSplitPodArrayCastInfo(op, arrTy, splitIds, splitTypes)) { + return false; + } + + ReadPodOp fieldRead = peelUnifiableCasts(op.getOperand(0)).getDefiningOp(); + if (!fieldRead) { + return false; + } + + SmallVector ignoredLeafArrays; + return tryCollectReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, ignoredLeafArrays + ) || + isFreshUnwrittenPodRead(fieldRead); + } + + LogicalResult matchAndRewrite( + UnrealizedConversionCastOp op, OpAdaptor, ConversionPatternRewriter &rewriter + ) const override { + ArrayType arrTy; + SmallVector splitIds; + SmallVector splitTypes; + if (!getDeferredSplitPodArrayCastInfo(op, arrTy, splitIds, splitTypes)) { + return failure(); + } + + ReadPodOp fieldRead = peelUnifiableCasts(op.getOperand(0)).getDefiningOp(); + if (!fieldRead) { + return failure(); + } + + SmallVector splitLeafArrays; + if (!resolveReadPodSplitPodArrayLeafValues( + fieldRead, arrTy, splitIds, splitTypes, virtualPods, deferredPodArrays, op.getLoc(), + rewriter, splitLeafArrays + )) { + return failure(); + } + + rewriter.replaceOp(op, splitLeafArrays); + eraseDeadDeferredFieldReadChain(fieldRead, rewriter); return success(); } }; @@ -2552,9 +2715,10 @@ class ResolveVirtualPodReadOp : public OpConversionPattern { llvm::append_range(fullChain, id.nameList); nestedLeafValues[id] = leafValues->at(RecordChain(fullChain)); }); - NewPodOp pod = rewriter.create(op.getLoc(), nestedPodTy); - virtualPods[pod] = std::move(nestedLeafValues); - rewriter.replaceOp(op, pod); + Value virtualPod = + createVirtualPodPlaceholder(rewriter, op.getLoc(), nestedPodTy, nestedLeafValues); + virtualPods[virtualPod] = std::move(nestedLeafValues); + rewriter.replaceOp(op, virtualPod); return success(); } @@ -2587,6 +2751,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM ctx, symTables, memberRepMap, virtualPods ); patterns.add(ctx, virtualPods, deferredPodArrays); + patterns.add(ctx, virtualPods, deferredPodArrays); patterns.add(ctx, virtualPods); ConversionTarget target(*ctx); @@ -2601,6 +2766,11 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp([&virtualPods](ReadArrayOp op) { return !ResolvePodReadBackedArrayReadOp::canResolve(op, virtualPods); }); + target.addDynamicallyLegalOp( + [&virtualPods](UnrealizedConversionCastOp op) { + return !ResolveDeferredSplitPodArrayCastOp::canResolve(op, virtualPods); + } + ); target.addDynamicallyLegalOp([&virtualPods](ReadPodOp op) { return !canResolveVirtualPodRead(op, virtualPods); }); @@ -2620,6 +2790,39 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM materializeVirtualPod(builder, newPod, leafValues); } } + + bool erasedDeadPlaceholderOps = false; + do { + SmallVector deadPlaceholderOps; + modOp->walk([&](Operation *op) { + if (auto readOp = llvm::dyn_cast(op)) { + if (readOp.getResult().use_empty()) { + deadPlaceholderOps.push_back(op); + } + return; + } + + if (auto castOp = llvm::dyn_cast(op)) { + if (llvm::all_of(castOp.getResults(), [](Value result) { return result.use_empty(); })) { + deadPlaceholderOps.push_back(op); + } + } + }); + for (Operation *op : deadPlaceholderOps) { + op->erase(); + } + erasedDeadPlaceholderOps = !deadPlaceholderOps.empty(); + } while (erasedDeadPlaceholderOps); + + SmallVector deadOps; + modOp->walk([&](Operation *op) { + if (op != modOp.getOperation() && isOpTriviallyDead(op)) { + deadOps.push_back(op); + } + }); + for (Operation *op : deadOps) { + op->erase(); + } return success(); } diff --git a/test/Transforms/PodToScalar/array_write_direct_pod_dynamic_nested_array.llzk b/test/Transforms/PodToScalar/array_write_direct_pod_dynamic_nested_array.llzk new file mode 100644 index 000000000..f8ca47829 --- /dev/null +++ b/test/Transforms/PodToScalar/array_write_direct_pod_dynamic_nested_array.llzk @@ -0,0 +1,26 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#inner = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index]> +!Outer = !pod.type<[@items: !array.type<#inner x !Item>, @tag: index]> +!OuterArray = !array.type<2 x !Outer> +module attributes {llzk.lang} { + function.def @write_elem(%arr: !OuterArray, %elem: !Outer, %i: index) -> !OuterArray { + array.write %arr[%i] = %elem : !OuterArray, !Outer + function.return %arr : !OuterArray + } +} +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @write_elem( +// CHECK-SAME: %[[ARR_ITEMS:[0-9a-zA-Z_\.]+]]: !array.type<2,#[[$MAP]] x index>, +// CHECK-SAME: %[[ARR_TAG:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[ELEM_ITEMS:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, +// CHECK-SAME: %[[ELEM_TAG:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[IDX:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> (!array.type<2,#[[$MAP]] x index>, !array.type<2 x index>) { +// CHECK-NEXT: array.insert %[[ARR_ITEMS]]{{\[}}%[[IDX]]] = %[[ELEM_ITEMS]] : <2,#[[$MAP]] x index>, <#[[$MAP]] x index> +// CHECK-NEXT: array.write %[[ARR_TAG]]{{\[}}%[[IDX]]] = %[[ELEM_TAG]] : <2 x index>, index +// CHECK-NEXT: function.return %[[ARR_ITEMS]], %[[ARR_TAG]] : !array.type<2,#[[$MAP]] x index>, !array.type<2 x index> +// CHECK-NEXT: } +// CHECK-NEXT: } From 99563416f5fa156d52ee784517ccbb5bf55bff61 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 21:50:00 -0500 Subject: [PATCH 26/36] fix: Split pod.read-backed array returns --- .../POD/Transforms/PodToScalarPass.cpp | 54 ++++++++++++++++++- .../return_array_from_pod_field.llzk | 20 +++++++ 2 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 test/Transforms/PodToScalar/return_array_from_pod_field.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index eb3a49caf..8f3e5d0ac 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -1576,10 +1576,62 @@ class SplitPodArrayInReturnOp : public OpConversionPattern { if (legal(op)) { return failure(); } - SmallVector newOperands = flattenConvertedValues(adaptor.getOperands()); + SmallVector newOperands; + for (auto [operand, convertedValues] : + llvm::zip_equal(op.getOperands(), adaptor.getOperands())) { + collectSplitPodArrayOperandValues( + op.getLoc(), operand, convertedValues, newOperands, rewriter + ); + } rewriter.replaceOpWithNewOp(op, ValueRange(newOperands)); return success(); } + + /// Append the split leaf-array values for one step-2 operand. + /// + /// When dialect conversion has already produced the parallel leaf arrays, reuse those converted + /// values directly. Otherwise derive the split arrays from the original aggregate operand so uses + /// like `function.return` can still flatten a raw `pod.read` of an array field. + static void collectSplitPodArrayOperandValues( + Location loc, Value originalOperand, ValueRange convertedValues, + SmallVectorImpl &newOperands, ConversionPatternRewriter &rewriter + ) { + ArrayType arrTy = splittablePodArray(originalOperand.getType()); + if (!arrTy) { + llvm::append_range(newOperands, convertedValues); + return; + } + + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + auto isDirectAggregateToSplitCast = [&convertedValues, &originalOperand]() { + if (convertedValues.empty()) { + return false; + } + auto castOp = convertedValues.front().getDefiningOp(); + if (!castOp || castOp->getNumOperands() != 1 || castOp.getOperand(0) != originalOperand) { + return false; + } + return llvm::all_of(convertedValues, [&castOp](Value value) { + return value.getDefiningOp() == castOp; + }); + }; + + if (!isDirectAggregateToSplitCast() && convertedValues.size() == splitTypes.size() && + llvm::all_of(llvm::zip_equal(convertedValues, splitTypes), [](auto pair) { + return typesUnify(std::get<0>(pair).getType(), std::get<1>(pair)); + })) { + llvm::append_range(newOperands, convertedValues); + return; + } + + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + Value splitValue = genReadAlongPath(rewriter, loc, originalOperand, id); + newOperands.push_back(castValueToTypeIfNeeded(rewriter, loc, splitValue, splitType)); + } + } }; /// Rewrite calls whose arguments or results contain arrays-of-POD to use the split signature. diff --git a/test/Transforms/PodToScalar/return_array_from_pod_field.llzk b/test/Transforms/PodToScalar/return_array_from_pod_field.llzk new file mode 100644 index 000000000..4e1a36407 --- /dev/null +++ b/test/Transforms/PodToScalar/return_array_from_pod_field.llzk @@ -0,0 +1,20 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Item = !pod.type<[@x: index, @y: !felt.type]> +!Outer = !pod.type<[@items: !array.type<#outer x !Item>]> +module attributes {llzk.lang} { + function.def @get_items(%p: !Outer) -> !array.type<#outer x !Item> { + %items = pod.read %p[@items] : !Outer, !array.type<#outer x !Item> + function.return %items : !array.type<#outer x !Item> + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @get_items( +// CHECK-SAME: %[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, +// CHECK-SAME: %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type> +// CHECK-SAME: ) -> (!array.type<#[[$M]] x index>, !array.type<#[[$M]] x !felt.type>) { +// CHECK-NEXT: function.return %[[VAL_0]], %[[VAL_1]] : !array.type<#[[$M]] x index>, !array.type<#[[$M]] x !felt.type> +// CHECK-NEXT: } +// CHECK-NEXT: } From ce0ce8f7ae704c87100f6ddb92a7d4debd180533 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Wed, 1 Jul 2026 22:06:47 -0500 Subject: [PATCH 27/36] fix: Rewrite quantifiers over POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 120 +++++++++++++++++- .../PodToScalar/bool_quantifiers.llzk | 79 ++++++++++++ 2 files changed, 195 insertions(+), 4 deletions(-) create mode 100644 test/Transforms/PodToScalar/bool_quantifiers.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 8f3e5d0ac..c70a3fd07 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -22,7 +22,7 @@ /// /// 2. Run a dialect conversion that splits arrays whose element type is a POD into parallel arrays /// in `llzk.nondet`, `array.*`, `constrain.eq`, `constrain.in`, `struct.readm`, `struct.writem`, -/// `function.def`, `function.call`, and `function.return`. +/// `function.def`, `function.call`, `function.return`, and bool quantifiers. /// /// 3. Run a dialect conversion that does the following: /// @@ -68,6 +68,7 @@ #include "llzk/Dialect/Array/IR/Ops.h" #include "llzk/Dialect/Array/IR/Types.h" #include "llzk/Dialect/Bool/IR/Dialect.h" +#include "llzk/Dialect/Bool/IR/Ops.h" #include "llzk/Dialect/Cast/IR/Dialect.h" #include "llzk/Dialect/Constrain/IR/Dialect.h" #include "llzk/Dialect/Constrain/IR/Ops.h" @@ -1883,6 +1884,116 @@ class SplitPodArrayLengthOp : public OpConversionPattern { } }; +/// Rebuild the current quantifier iterand from one read or extract per split POD-array leaf. +static Value rebuildSplitPodArrayQuantifierIterValue( + OpBuilder &bldr, Location loc, Type iterType, Value index, ArrayType sortType, + ValueRange convertedSort +) { + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(sortType, splitTypes, &splitIds); + assert( + convertedSort.size() == splitIds.size() && + "converted quantifier sort must provide one value per POD-array leaf" + ); + + DenseMap leafValues; + for (auto [id, leafArray] : llvm::zip_equal(splitIds, convertedSort)) { + SmallVector indices {index}; + leafValues[id] = genArrayRead(bldr, loc, leafArray, indices); + } + + SmallVector recordChain; + return rebuildFlattenedPodRecord(bldr, loc, iterType, recordChain, leafValues); +} + +/// Lower a bool quantifier over an array-of-POD to an `scf.for` over the split leaf arrays. +template +static LogicalResult rewriteSplitPodArrayQuantifier( + QuantifierOp op, ValueRange convertedSort, ConversionPatternRewriter &rewriter, + bool initialValue +) { + ArrayType sortType = llvm::cast(op.getSort().getType()); + Location loc = op.getLoc(); + + Value shapeCarrier = convertedSort.empty() + ? materializeArrayLengthCarrier(op.getSort(), sortType, loc, rewriter) + : convertedSort.front(); + Value lowerBound = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value upperBound = rewriter.create(loc, shapeCarrier, lowerBound); + Value step = rewriter.create(loc, rewriter.getIndexAttr(1)); + Value init = rewriter.create( + loc, IntegerAttr::get(IntegerType::get(rewriter.getContext(), 1), initialValue ? 1 : 0) + ); + + auto loop = rewriter.create(loc, lowerBound, upperBound, step, ValueRange {init}); + loop->setDiscardableAttrs(op->getDiscardableAttrDictionary()); + + Block &loopBody = *loop.getBody(); + if (!loopBody.empty()) { + rewriter.eraseOp(&loopBody.back()); + } + + rewriter.setInsertionPointToStart(&loopBody); + Value iterValue = rebuildSplitPodArrayQuantifierIterValue( + rewriter, loc, op.getBody()->getArgument(0).getType(), loop.getInductionVar(), sortType, + convertedSort + ); + + IRMapping mapping; + mapping.map(op.getBody()->getArgument(0), iterValue); + + for (Operation &nestedOp : op.getBody()->without_terminator()) { + rewriter.clone(nestedOp, mapping); + } + + auto yieldOp = llvm::cast(op.getBody()->getTerminator()); + Value predicate = mapping.lookupOrDefault(yieldOp.getValue()); + Value combined = rewriter.create(loc, loop.getRegionIterArg(0), predicate); + rewriter.create(loc, combined); + + rewriter.replaceOp(op, loop.getResults()); + return success(); +} + +/// Rewrite `bool.forall` over an array-of-POD to iterate over the split leaf arrays directly. +class SplitPodArrayForAllOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(boolean::ForAllOp op) { return !splittablePodArray(op.getSort().getType()); } + + LogicalResult matchAndRewrite( + boolean::ForAllOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + return rewriteSplitPodArrayQuantifier( + op, adaptor.getSort(), rewriter, /*initialValue=*/true + ); + } +}; + +/// Rewrite `bool.exists` over an array-of-POD to iterate over the split leaf arrays directly. +class SplitPodArrayExistsOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(boolean::ExistsOp op) { return !splittablePodArray(op.getSort().getType()); } + + LogicalResult matchAndRewrite( + boolean::ExistsOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + return rewriteSplitPodArrayQuantifier( + op, adaptor.getSort(), rewriter, /*initialValue=*/false + ); + } +}; + /// Rewrite `array.extract` of an array-of-POD subarray into one extract per parallel leaf array. class SplitPodArrayExtractArrayOp : public OpConversionPattern { public: @@ -2044,9 +2155,8 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM SplitPodArrayNonDetOp, SplitPodArrayCreateArrayOp, SplitPodArrayReadArrayOp, SplitPodArrayWriteArrayOp, SplitPodArrayExtractArrayOp, SplitPodArrayInsertArrayOp, SplitPodArrayInFuncDefOp, SplitPodArrayInReturnOp, SplitPodArrayInCallOp, - SplitPodArrayInEmitEqualityOp, SplitPodArrayInEmitContainmentOp, SplitPodArrayLengthOp>( - typeConverter, ctx - ); + SplitPodArrayInEmitEqualityOp, SplitPodArrayInEmitContainmentOp, SplitPodArrayLengthOp, + SplitPodArrayForAllOp, SplitPodArrayExistsOp>(typeConverter, ctx); patterns.add( typeConverter, ctx, symTables, memberRepMap ); @@ -2068,6 +2178,8 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM SplitPodArrayInEmitContainmentOp::legal ); target.addDynamicallyLegalOp(SplitPodArrayLengthOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayForAllOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayExistsOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInMemberWriteOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInMemberReadOp::legal); diff --git a/test/Transforms/PodToScalar/bool_quantifiers.llzk b/test/Transforms/PodToScalar/bool_quantifiers.llzk new file mode 100644 index 000000000..126f38c5c --- /dev/null +++ b/test/Transforms/PodToScalar/bool_quantifiers.llzk @@ -0,0 +1,79 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Pair = !pod.type<[@x: index, @y: index]> +module attributes {llzk.lang} { + function.def @forall_arg(%arr: !array.type<2 x !Pair>, %limit: index) -> i1 attributes {function.allow_non_native_field_ops} { + %all = bool.forall %elt in %arr : !array.type<2 x !Pair> { + %x = pod.read %elt[@x] : !Pair, index + %ok = arith.cmpi slt, %x, %limit : index + bool.yield %ok + } + function.return %all : i1 + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @forall_arg( +// CHECK-SAME: %[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, +// CHECK-SAME: %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type<2 x index>, %[[LIMIT:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> i1 attributes {function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[LB:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[UB:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[LB]] : <2 x index> +// CHECK-NEXT: %[[STEP:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[INIT:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: %[[RES:[0-9a-zA-Z_\.]+]] = scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ACC:[0-9a-zA-Z_\.]+]] = %[[INIT]]) -> (i1) { +// CHECK-DAG: %[[X:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_X]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[Y:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_Y]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[OK:[0-9a-zA-Z_\.]+]] = arith.cmpi slt, %[[X]], %[[LIMIT]] : index +// CHECK-NEXT: %[[COMBINED:[0-9a-zA-Z_\.]+]] = bool.and %[[ACC]], %[[OK]] : i1, i1 +// CHECK-NEXT: scf.yield %[[COMBINED]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: function.return %[[RES]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Pair = !pod.type<[@x: index, @y: index]> +module attributes {llzk.lang} { + function.def @exists_array_new(%x0: index, %y0: index, %x1: index, %y1: index) -> i1 attributes {function.allow_non_native_field_ops} { + %p0 = pod.new { @x = %x0, @y = %y0 } : !Pair + %p1 = pod.new { @x = %x1, @y = %y1 } : !Pair + %arr = array.new %p0, %p1 : !array.type<2 x !Pair> + %any = bool.exists %elt in %arr : !array.type<2 x !Pair> { + %y = pod.read %elt[@y] : !Pair, index + %ok = arith.cmpi eq, %y, %y1 : index + bool.yield %ok + } + function.return %any : i1 + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @exists_array_new( +// CHECK-SAME: %[[X0:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[Y0:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[X1:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[Y1:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) -> i1 attributes {function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[ARR_X:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: %[[C0X:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.write %[[ARR_X]][%[[C0X]]] = %[[X0]] : <2 x index>, index +// CHECK-NEXT: %[[C1X:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[ARR_X]][%[[C1X]]] = %[[X1]] : <2 x index>, index +// CHECK-NEXT: %[[ARR_Y:[0-9a-zA-Z_\.]+]] = array.new : <2 x index> +// CHECK-NEXT: %[[C0Y:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: array.write %[[ARR_Y]][%[[C0Y]]] = %[[Y0]] : <2 x index>, index +// CHECK-NEXT: %[[C1Y:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: array.write %[[ARR_Y]][%[[C1Y]]] = %[[Y1]] : <2 x index>, index +// CHECK-NEXT: %[[LB:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[UB:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[LB]] : <2 x index> +// CHECK-NEXT: %[[STEP:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[INIT:[0-9a-zA-Z_\.]+]] = arith.constant false +// CHECK-NEXT: %[[RES:[0-9a-zA-Z_\.]+]] = scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ACC:[0-9a-zA-Z_\.]+]] = %[[INIT]]) -> (i1) { +// CHECK-DAG: %[[X:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_X]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[Y:[0-9a-zA-Z_\.]+]] = array.read %[[ARR_Y]][%[[IV]]] : <2 x index>, index +// CHECK-DAG: %[[OK:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[Y]], %[[Y1]] : index +// CHECK-NEXT: %[[COMBINED:[0-9a-zA-Z_\.]+]] = bool.or %[[ACC]], %[[OK]] : i1, i1 +// CHECK-NEXT: scf.yield %[[COMBINED]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: function.return %[[RES]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: } From ac379dd4f75913f86ba9a6e7a2149e3020eb3e91 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 08:08:29 -0500 Subject: [PATCH 28/36] clang-tidy --- lib/Dialect/POD/Transforms/PodToScalarPass.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index c70a3fd07..19de7b3e4 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -816,7 +816,7 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef leafArrays; if (tryCollectDirectSplitPodArrayLeafValues(value, arrTy, splitTypes, leafArrays)) { - auto it = llvm::find(splitIds, RecordChain(recordChain)); + auto *it = llvm::find(splitIds, RecordChain(recordChain)); assert(it != splitIds.end() && "record path must name a flattened POD array leaf"); return leafArrays[std::distance(splitIds.begin(), it)]; } @@ -825,7 +825,7 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef()) { auto splitLeafReads = bldr.create(loc, TypeRange(splitTypes), strippedValue); - auto it = llvm::find(splitIds, RecordChain(recordChain)); + auto *it = llvm::find(splitIds, RecordChain(recordChain)); assert(it != splitIds.end() && "record path must name a flattened POD array leaf"); return splitLeafReads.getResult(std::distance(splitIds.begin(), it)); } From 3b352b39bad4ef7961e6bf5e3c6e3ba702c2761f Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 09:34:39 -0500 Subject: [PATCH 29/36] fix: Honor writes before resolving virtual POD reads --- .../POD/Transforms/PodToScalarPass.cpp | 176 +++++++++++++++--- .../virtual_pod_write_then_read.llzk | 82 ++++++++ 2 files changed, 227 insertions(+), 31 deletions(-) create mode 100644 test/Transforms/PodToScalar/virtual_pod_write_then_read.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 19de7b3e4..2548466b3 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -176,6 +176,13 @@ static bool hasNestedWriteToRecord(Operation &op, Value podRef, StringAttr recor }); } +/// Return whether `op` contains a nested write to any record of `podRef`. +static bool hasNestedWriteToPod(Operation &op, Value podRef) { + return walkContainsMatch(op, [&](WritePodOp writeOp) { + return writeOp.getOperation() != &op && writeOp.getPodRef() == podRef; + }); +} + /// Return whether `op` contains any read from `podRef.recordName`. static bool hasReadFromRecord(Operation &op, Value podRef, StringAttr recordName) { return walkContainsMatch(op, [&podRef, &recordName](ReadPodOp readOp) { @@ -748,21 +755,41 @@ static bool tryCollectDirectSplitPodArrayLeafValues( return false; } +/// Return whether `op` is preceded in its block by a write to `podRef.recordName`. +static bool hasEarlierWriteToRecordInBlock(Operation *op, Value podRef, StringAttr recordName) { + for (Operation &candidate : *op->getBlock()) { + if (&candidate == op) { + return false; + } + if (auto writeOp = dyn_cast(&candidate)) { + if (isSamePodRecord(writeOp, podRef, recordName)) { + return true; + } + } else if (hasNestedWriteToRecord(candidate, podRef, recordName)) { + return true; + } + } + return false; +} + /// Return whether the read is preceded by a write to the same pod record within its block. static bool hasEarlierWriteInBlock(ReadPodOp readOp) { - Value podRef = readOp.getPodRef(); - StringAttr recordName = readOp.getRecordNameAttr(); + return hasEarlierWriteToRecordInBlock( + readOp.getOperation(), readOp.getPodRef(), readOp.getRecordNameAttr() + ); +} - for (Operation &op : *readOp->getBlock()) { - if (&op == readOp.getOperation()) { +/// Return whether `op` is preceded in its block by any write to `podRef`. +static bool hasEarlierWriteToPodInBlock(Operation *op, Value podRef) { + for (Operation &candidate : *op->getBlock()) { + if (&candidate == op) { return false; } - - if (auto writeOp = dyn_cast(&op)) { - if (isSamePodRecord(writeOp, podRef, recordName)) { + if (auto writeOp = dyn_cast(&candidate)) { + if (writeOp.getPodRef() == podRef) { return true; } - } else if (hasNestedWriteToRecord(op, podRef, recordName)) { + } else if (hasNestedWriteToPod(candidate, podRef)) { return true; } } @@ -992,7 +1019,8 @@ materializeVirtualPod(OpBuilder &bldr, NewPodOp pod, const VirtualPodLeafMap &le /// Return `true` iff a read from a virtual POD can be resolved without materializing it. static bool canResolveVirtualPodRead(ReadPodOp op, const VirtualPodValueMap &virtualPods) { - if (!lookupVirtualPodLeafMap(op.getPodRef(), virtualPods)) { + if (!lookupVirtualPodLeafMap(op.getPodRef(), virtualPods) || hasEarlierWriteInBlock(op) || + findNearestForwardableWriteInBlock(op)) { return false; } Type recType = llvm::cast(op.getPodRefType()).getRecordMap().lookup(op.getRecordName()); @@ -1026,11 +1054,13 @@ static SmallVector getSplitRecordNameSuffixes(Type type) { // add the original operand to the list. static void processInputOperand( Location loc, Value operand, SmallVector &newOperands, - ConversionPatternRewriter &rewriter, const VirtualPodValueMap *virtualPods = nullptr + ConversionPatternRewriter &rewriter, Operation *userOp = nullptr, + const VirtualPodValueMap *virtualPods = nullptr ) { if (PodType pt = splittablePod(operand.getType())) { if (virtualPods) { - if (const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(operand, *virtualPods)) { + if (const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(operand, *virtualPods); + leafValues && (!userOp || !hasEarlierWriteToPodInBlock(userOp, operand))) { llvm::append_range( newOperands, orderedVirtualPodLeafValues(pt, loc, rewriter, *leafValues) ); @@ -1054,13 +1084,56 @@ static void processInputOperands( ) { SmallVector newOperands; for (Value v : operands) { - processInputOperand(op->getLoc(), v, newOperands, rewriter, virtualPods); + processInputOperand(op->getLoc(), v, newOperands, rewriter, op, virtualPods); } rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() { outputOpRef.assign(ValueRange(newOperands)); }); } +/// Update the tracked leaf values for one top-level POD record after a virtual `pod.write`. +static void updateVirtualPodRecordLeafValues( + Location loc, StringAttr recordName, Type recordType, Value recordValue, + const VirtualPodValueMap &virtualPods, ConversionPatternRewriter &rewriter, + VirtualPodLeafMap &leafValues +) { + SmallVector prefix {recordName}; + + if (PodType nestedPodTy = llvm::dyn_cast(recordType)) { + if (const VirtualPodLeafMap *nestedLeafValues = + lookupVirtualPodLeafMap(recordValue, virtualPods)) { + SmallVector nestedRecordChain; + forEachPodLeaf(nestedPodTy, nestedRecordChain, [&](const RecordChain &id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + leafValues[RecordChain(fullChain)] = nestedLeafValues->at(id); + }); + return; + } + + SmallVector nestedRecordChain; + forEachPodLeaf(nestedPodTy, nestedRecordChain, [&](const RecordChain &id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + leafValues[RecordChain(fullChain)] = genReadAlongPath(rewriter, loc, recordValue, id); + }); + return; + } + + if (ArrayType arrTy = splittablePodArray(recordType)) { + auto elemPodTy = llvm::cast(arrTy.getElementType()); + SmallVector nestedRecordChain; + forEachPodLeaf(elemPodTy, nestedRecordChain, [&](const RecordChain &id, Type) { + SmallVector fullChain(prefix); + llvm::append_range(fullChain, id.nameList); + leafValues[RecordChain(fullChain)] = genReadAlongPath(rewriter, loc, recordValue, id); + }); + return; + } + + leafValues[RecordChain(prefix)] = castValueToTypeIfNeeded(rewriter, loc, recordValue, recordType); +} + /// Register the dialects and operations that remain legal across the conversion-based stages. inline static void baseTargetSetup(ConversionTarget &target) { target.addLegalDialect< @@ -2245,7 +2318,9 @@ class SplitPodElementCreateArrayOp : public OpConversionPattern { for (Value element : adaptor.getElements()) { SmallVector flattenedValues; if (splittablePod(element.getType())) { - processInputOperand(op.getLoc(), element, flattenedValues, rewriter, &virtualPods); + processInputOperand( + op.getLoc(), element, flattenedValues, rewriter, op.getOperation(), &virtualPods + ); } else { flattenedValues.push_back(element); } @@ -2513,7 +2588,9 @@ class SplitPodInMemberWriteOp : public OpConversionPattern { const LocalMemberReplacementMap &idToMember = repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr()); const VirtualPodLeafMap *virtualLeafValues = - lookupVirtualPodLeafMap(adaptor.getVal(), virtualPods); + !hasEarlierWriteToPodInBlock(op.getOperation(), adaptor.getVal()) + ? lookupVirtualPodLeafMap(adaptor.getVal(), virtualPods) + : nullptr; for (const auto &[id, newMember] : idToMember) { Value scalarValue = virtualLeafValues @@ -2601,28 +2678,30 @@ static bool tryCollectReadPodSplitPodArrayLeafValues( ReadPodOp readOp, ArrayType arrTy, ArrayRef splitIds, ArrayRef splitTypes, const VirtualPodValueMap &virtualPods, SmallVectorImpl &leafArrays ) { - if (const VirtualPodLeafMap *podLeafValues = - lookupVirtualPodLeafMap(readOp.getPodRef(), virtualPods)) { - leafArrays.reserve(splitIds.size()); - for (const RecordChain &id : splitIds) { - SmallVector fullChain {readOp.getRecordNameAttr()}; - llvm::append_range(fullChain, id.nameList); - auto it = podLeafValues->find(RecordChain(fullChain)); - if (it == podLeafValues->end() || - !typesUnify(it->second.getType(), getFlattenedTypeAlongPath(arrTy, id.nameList))) { - return false; - } - leafArrays.push_back(it->second); - } - return true; - } - if (WritePodOp writeOp = findNearestForwardableWriteInBlock(readOp)) { return tryCollectMaterializedSplitPodArrayLeafValues( writeOp.getValue(), arrTy, splitTypes, leafArrays ); } + if (!hasEarlierWriteInBlock(readOp)) { + if (const VirtualPodLeafMap *podLeafValues = + lookupVirtualPodLeafMap(readOp.getPodRef(), virtualPods)) { + leafArrays.reserve(splitIds.size()); + for (const RecordChain &id : splitIds) { + SmallVector fullChain {readOp.getRecordNameAttr()}; + llvm::append_range(fullChain, id.nameList); + auto it = podLeafValues->find(RecordChain(fullChain)); + if (it == podLeafValues->end() || + !typesUnify(it->second.getType(), getFlattenedTypeAlongPath(arrTy, id.nameList))) { + return false; + } + leafArrays.push_back(it->second); + } + return true; + } + } + return false; } @@ -2847,6 +2926,34 @@ class ResolveDeferredSplitPodArrayCastOp : public OpConversionPattern { + VirtualPodValueMap &virtualPods; + +public: + ResolveVirtualPodWriteOp(MLIRContext *ctx, VirtualPodValueMap &virtualPodMap) + : OpConversionPattern(ctx), virtualPods(virtualPodMap) {} + + LogicalResult matchAndRewrite( + WritePodOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + auto it = virtualPods.find(adaptor.getPodRef()); + if (it == virtualPods.end()) { + return failure(); + } + + Type recordType = + llvm::cast(op.getPodRefType()).getRecordMap().lookup(op.getRecordName()); + assert(recordType && "record must exist in POD type"); + updateVirtualPodRecordLeafValues( + op.getLoc(), op.getRecordNameAttr(), recordType, adaptor.getValue(), virtualPods, rewriter, + it->second + ); + rewriter.eraseOp(op); + return success(); + } +}; + /// Resolve reads from a virtual POD placeholder without materializing the whole aggregate. /// /// This pattern answers `pod.read` directly from virtual leaf storage, rebuilding nested POD @@ -2861,6 +2968,10 @@ class ResolveVirtualPodReadOp : public OpConversionPattern { LogicalResult matchAndRewrite( ReadPodOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter ) const override { + if (hasEarlierWriteInBlock(op) || findNearestForwardableWriteInBlock(op)) { + return failure(); + } + const VirtualPodLeafMap *leafValues = lookupVirtualPodLeafMap(adaptor.getPodRef(), virtualPods); if (!leafValues) { return failure(); @@ -2916,7 +3027,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM ); patterns.add(ctx, virtualPods, deferredPodArrays); patterns.add(ctx, virtualPods, deferredPodArrays); - patterns.add(ctx, virtualPods); + patterns.add(ctx, virtualPods); ConversionTarget target(*ctx); baseTargetSetup(target); @@ -2927,6 +3038,9 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp(SplitPodInCallOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberWriteOp::legal); target.addDynamicallyLegalOp(SplitPodInMemberReadOp::legal); + target.addDynamicallyLegalOp([&virtualPods](WritePodOp op) { + return !lookupVirtualPodLeafMap(op.getPodRef(), virtualPods); + }); target.addDynamicallyLegalOp([&virtualPods](ReadArrayOp op) { return !ResolvePodReadBackedArrayReadOp::canResolve(op, virtualPods); }); diff --git a/test/Transforms/PodToScalar/virtual_pod_write_then_read.llzk b/test/Transforms/PodToScalar/virtual_pod_write_then_read.llzk new file mode 100644 index 000000000..57d8107a3 --- /dev/null +++ b/test/Transforms/PodToScalar/virtual_pod_write_then_read.llzk @@ -0,0 +1,82 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + function.def @arg_case(%p: !Single, %new: index) -> index { + pod.write %p[@x] = %new : !Single, index + %x = pod.read %p[@x] : !Single, index + function.return %x : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @arg_case(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: function.return %[[VAL_1]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + function.def @mk(%old: index) -> !Single { + %p = pod.new : !Single + pod.write %p[@x] = %old : !Single, index + function.return %p : !Single + } + + function.def @call_case(%old: index, %new: index) -> index { + %p = function.call @mk(%old) : (index) -> !Single + pod.write %p[@x] = %new : !Single, index + %x = pod.read %p[@x] : !Single, index + function.return %x : index + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @mk(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: function.return %[[VAL_0]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_case(%[[VAL_1:[0-9a-zA-Z_\.]+]]: index, %[[VAL_2:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = function.call @mk(%[[VAL_1]]) : (index) -> index +// CHECK-NEXT: function.return %[[VAL_2]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + struct.def @Box { + struct.member @p : !Single + struct.member @out : index + + function.def @compute(%old: index, %new: index) -> !struct.type<@Box> { + %self = struct.new : !struct.type<@Box> + %p = pod.new : !Single + pod.write %p[@x] = %old : !Single, index + struct.writem %self[@p] = %p : !struct.type<@Box>, !Single + %loaded = struct.readm %self[@p] : !struct.type<@Box>, !Single + pod.write %loaded[@x] = %new : !Single, index + %x = pod.read %loaded[@x] : !Single, index + struct.writem %self[@out] = %x : !struct.type<@Box>, index + function.return %self : !struct.type<@Box> + } + + function.def @constrain(%self: !struct.type<@Box>, %old: index, %new: index) { + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Box { +// CHECK-NEXT: struct.member @p_x : index +// CHECK-NEXT: struct.member @out : index +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index, %[[VAL_1:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@Box> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@Box> +// CHECK-NEXT: struct.writem %[[VAL_2]][@p_x] = %[[VAL_0]] : <@Box>, index +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@p_x] : <@Box>, index +// CHECK-NEXT: struct.writem %[[VAL_2]][@out] = %[[VAL_1]] : <@Box>, index +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@Box> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !struct.type<@Box>, %[[VAL_5:[0-9a-zA-Z_\.]+]]: index, %[[VAL_6:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } From 8ea12b1905161b82728e549966706a0e6fd3abb4 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 10:12:43 -0500 Subject: [PATCH 30/36] fix: Split unifiable casts for POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 159 ++++++++++++------ .../unifiable_cast_array_of_pod.llzk | 31 ++++ 2 files changed, 141 insertions(+), 49 deletions(-) create mode 100644 test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 2548466b3..338b077ba 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -1635,6 +1635,112 @@ class SplitPodArrayInFuncDefOp : public OpConversionPattern { } }; +/// Append the split leaf-array values for one step-2 operand. +/// +/// When dialect conversion has already produced the parallel leaf arrays, reuse those converted +/// values directly. Otherwise derive the split arrays from the original aggregate operand so users +/// like `poly.unifiable_cast` and `function.return` can still flatten a raw `pod.read` of an array +/// field. +static void collectSplitPodArrayOperandValues( + Location loc, Value originalOperand, ValueRange convertedValues, + SmallVectorImpl &newOperands, ConversionPatternRewriter &rewriter +) { + ArrayType arrTy = splittablePodArray(originalOperand.getType()); + if (!arrTy) { + llvm::append_range(newOperands, convertedValues); + return; + } + + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + + auto isDirectAggregateToSplitCast = [&convertedValues, &originalOperand]() { + if (convertedValues.empty()) { + return false; + } + auto castOp = convertedValues.front().getDefiningOp(); + if (!castOp || castOp->getNumOperands() != 1 || castOp.getOperand(0) != originalOperand) { + return false; + } + return llvm::all_of(convertedValues, [&castOp](Value value) { + return value.getDefiningOp() == castOp; + }); + }; + + if (!isDirectAggregateToSplitCast() && convertedValues.size() == splitTypes.size() && + llvm::all_of(llvm::zip_equal(convertedValues, splitTypes), [](auto pair) { + return typesUnify(std::get<0>(pair).getType(), std::get<1>(pair)); + })) { + llvm::append_range(newOperands, convertedValues); + return; + } + + for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { + Value splitValue = genReadAlongPath(rewriter, loc, originalOperand, id); + newOperands.push_back(castValueToTypeIfNeeded(rewriter, loc, splitValue, splitType)); + } +} + +/// Rewrite array-of-POD `poly.unifiable_cast` into one leaf-array cast per split array. +class SplitPodArrayInUnifiableCastOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static bool legal(UnifiableCastOp op) { return !splittablePodArray(op.getType()); } + + LogicalResult matchAndRewrite( + UnifiableCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (legal(op)) { + return failure(); + } + + ArrayType inputArrTy = splittablePodArray(op.getInput().getType()); + if (!inputArrTy) { + return rewriter.notifyMatchFailure( + op, "expected array-of-pod cast input when rewriting array-of-pod cast result" + ); + } + + SmallVector inputSplitIds; + SmallVector inputSplitTypes; + splitPodArrayTypeTo(inputArrTy, inputSplitTypes, &inputSplitIds); + + ArrayType resultArrTy = llvm::cast(op.getType()); + SmallVector resultSplitIds; + SmallVector resultSplitTypes; + splitPodArrayTypeTo(resultArrTy, resultSplitTypes, &resultSplitIds); + + if (inputSplitIds != resultSplitIds) { + return rewriter.notifyMatchFailure( + op, "array-of-pod cast changed POD leaf structure unexpectedly" + ); + } + + SmallVector splitInputs; + collectSplitPodArrayOperandValues( + op.getLoc(), op.getInput(), adaptor.getInput(), splitInputs, rewriter + ); + if (splitInputs.size() != resultSplitTypes.size()) { + return rewriter.notifyMatchFailure( + op, "failed to collect one split input per array-of-pod cast leaf" + ); + } + + SmallVector replacements; + replacements.reserve(resultSplitTypes.size()); + for (auto [splitInput, resultSplitType] : llvm::zip_equal(splitInputs, resultSplitTypes)) { + replacements.push_back( + castValueToTypeIfNeeded(rewriter, op.getLoc(), splitInput, resultSplitType) + ); + } + + rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); + return success(); + } +}; + /// Rewrite `function.return` to flatten any array-of-POD operands into their parallel arrays. class SplitPodArrayInReturnOp : public OpConversionPattern { public: @@ -1660,52 +1766,6 @@ class SplitPodArrayInReturnOp : public OpConversionPattern { rewriter.replaceOpWithNewOp(op, ValueRange(newOperands)); return success(); } - - /// Append the split leaf-array values for one step-2 operand. - /// - /// When dialect conversion has already produced the parallel leaf arrays, reuse those converted - /// values directly. Otherwise derive the split arrays from the original aggregate operand so uses - /// like `function.return` can still flatten a raw `pod.read` of an array field. - static void collectSplitPodArrayOperandValues( - Location loc, Value originalOperand, ValueRange convertedValues, - SmallVectorImpl &newOperands, ConversionPatternRewriter &rewriter - ) { - ArrayType arrTy = splittablePodArray(originalOperand.getType()); - if (!arrTy) { - llvm::append_range(newOperands, convertedValues); - return; - } - - SmallVector splitIds; - SmallVector splitTypes; - splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); - - auto isDirectAggregateToSplitCast = [&convertedValues, &originalOperand]() { - if (convertedValues.empty()) { - return false; - } - auto castOp = convertedValues.front().getDefiningOp(); - if (!castOp || castOp->getNumOperands() != 1 || castOp.getOperand(0) != originalOperand) { - return false; - } - return llvm::all_of(convertedValues, [&castOp](Value value) { - return value.getDefiningOp() == castOp; - }); - }; - - if (!isDirectAggregateToSplitCast() && convertedValues.size() == splitTypes.size() && - llvm::all_of(llvm::zip_equal(convertedValues, splitTypes), [](auto pair) { - return typesUnify(std::get<0>(pair).getType(), std::get<1>(pair)); - })) { - llvm::append_range(newOperands, convertedValues); - return; - } - - for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { - Value splitValue = genReadAlongPath(rewriter, loc, originalOperand, id); - newOperands.push_back(castValueToTypeIfNeeded(rewriter, loc, splitValue, splitType)); - } - } }; /// Rewrite calls whose arguments or results contain arrays-of-POD to use the split signature. @@ -2227,9 +2287,9 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM patterns.add< SplitPodArrayNonDetOp, SplitPodArrayCreateArrayOp, SplitPodArrayReadArrayOp, SplitPodArrayWriteArrayOp, SplitPodArrayExtractArrayOp, SplitPodArrayInsertArrayOp, - SplitPodArrayInFuncDefOp, SplitPodArrayInReturnOp, SplitPodArrayInCallOp, - SplitPodArrayInEmitEqualityOp, SplitPodArrayInEmitContainmentOp, SplitPodArrayLengthOp, - SplitPodArrayForAllOp, SplitPodArrayExistsOp>(typeConverter, ctx); + SplitPodArrayInFuncDefOp, SplitPodArrayInUnifiableCastOp, SplitPodArrayInReturnOp, + SplitPodArrayInCallOp, SplitPodArrayInEmitEqualityOp, SplitPodArrayInEmitContainmentOp, + SplitPodArrayLengthOp, SplitPodArrayForAllOp, SplitPodArrayExistsOp>(typeConverter, ctx); patterns.add( typeConverter, ctx, symTables, memberRepMap ); @@ -2244,6 +2304,7 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM target.addDynamicallyLegalOp(SplitPodArrayExtractArrayOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInsertArrayOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInFuncDefOp::legal); + target.addDynamicallyLegalOp(SplitPodArrayInUnifiableCastOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInReturnOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInCallOp::legal); target.addDynamicallyLegalOp(SplitPodArrayInEmitEqualityOp::legal); diff --git a/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk b/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk new file mode 100644 index 000000000..7f6cb3b10 --- /dev/null +++ b/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk @@ -0,0 +1,31 @@ +// RUN: llzk-opt -split-input-file -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#map = affine_map<()[s0] -> (s0)> +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + function.def @sink(%arr: !array.type) -> index { + %c0 = arith.constant 0 : index + %len = array.len %arr, %c0 : !array.type + function.return %len : index + } + + function.def @main(%arr: !array.type<#map x !Pair>) -> (!array.type, index) { + %cast = poly.unifiable_cast %arr : (!array.type<#map x !Pair>) -> !array.type + %len = function.call @sink(%cast) : (!array.type) -> index + function.return %cast, %len : !array.type, index + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @sink(%[[ARR_X:[0-9a-zA-Z_\.]+]]: !array.type, %[[ARR_Y:[0-9a-zA-Z_\.]+]]: !array.type) -> index { +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR_X]], %[[C0]] : +// CHECK-NEXT: function.return %[[LEN]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @main(%[[ARG_X:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[ARG_Y:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> (!array.type, !array.type, index) { +// CHECK-NEXT: %[[CAST_X:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[ARG_X]] : (!array.type<#[[$M]] x index>) -> !array.type +// CHECK-NEXT: %[[CAST_Y:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[ARG_Y]] : (!array.type<#[[$M]] x !felt.type>) -> !array.type +// CHECK-NEXT: %[[CALL_LEN:[0-9a-zA-Z_\.]+]] = function.call @sink(%[[CAST_X]], %[[CAST_Y]]) : (!array.type, !array.type) -> index +// CHECK-NEXT: function.return %[[CAST_X]], %[[CAST_Y]], %[[CALL_LEN]] : !array.type, !array.type, index +// CHECK-NEXT: } +// CHECK-NEXT: } From ddc518b9c9dd75431e272d43dce58dbfa803e59d Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 10:52:01 -0500 Subject: [PATCH 31/36] fix: Split POD-array call operands from POD reads --- .../POD/Transforms/PodToScalarPass.cpp | 22 ++++++++++---- .../call_with_pod_read_array_of_pod.llzk | 30 +++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) create mode 100644 test/Transforms/PodToScalar/call_with_pod_read_array_of_pod.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 338b077ba..ed619030d 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -1655,16 +1655,22 @@ static void collectSplitPodArrayOperandValues( SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); - auto isDirectAggregateToSplitCast = [&convertedValues, &originalOperand]() { + auto isDirectAggregateToSplitCast = [&convertedValues, &splitTypes]() { if (convertedValues.empty()) { return false; } auto castOp = convertedValues.front().getDefiningOp(); - if (!castOp || castOp->getNumOperands() != 1 || castOp.getOperand(0) != originalOperand) { + if (!castOp || castOp->getNumOperands() != 1 || + !splittablePodArray(castOp.getOperand(0).getType()) || + castOp->getNumResults() != splitTypes.size()) { return false; } - return llvm::all_of(convertedValues, [&castOp](Value value) { - return value.getDefiningOp() == castOp; + + return llvm::all_of(llvm::zip_equal(convertedValues, splitTypes), [&castOp](auto pair) { + Value convertedValue = std::get<0>(pair); + Type splitType = std::get<1>(pair); + return convertedValue.getDefiningOp() == castOp && + typesUnify(convertedValue.getType(), splitType); }); }; @@ -1802,7 +1808,13 @@ class SplitPodArrayInCallOp : public OpConversionPattern { mapOperands.push_back(values); } - SmallVector newArgOperands = flattenConvertedValues(adaptor.getArgOperands()); + SmallVector newArgOperands; + for (auto [operand, convertedValues] : + llvm::zip_equal(op.getArgOperands(), adaptor.getArgOperands())) { + collectSplitPodArrayOperandValues( + op.getLoc(), operand, convertedValues, newArgOperands, rewriter + ); + } CallOp newCall = createCallPreservingInstantiationOperands( op.getLoc(), newResultTypes, op, mapOperands, newArgOperands, rewriter ); diff --git a/test/Transforms/PodToScalar/call_with_pod_read_array_of_pod.llzk b/test/Transforms/PodToScalar/call_with_pod_read_array_of_pod.llzk new file mode 100644 index 000000000..e1655de75 --- /dev/null +++ b/test/Transforms/PodToScalar/call_with_pod_read_array_of_pod.llzk @@ -0,0 +1,30 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#outer = affine_map<()[s0] -> (s0)> +!Pair = !pod.type<[@x: index, @y: !felt.type]> +!Holder = !pod.type<[@pairs: !array.type<#outer x !Pair>]> +module attributes {llzk.lang} { + function.def @sink(%arr: !array.type<#outer x !Pair>) -> index { + %c0 = arith.constant 0 : index + %len = array.len %arr, %c0 : !array.type<#outer x !Pair> + function.return %len : index + } + + function.def @main(%holder: !Holder) -> index { + %pairs = pod.read %holder[@pairs] : !Holder, !array.type<#outer x !Pair> + %len = function.call @sink(%pairs) : (!array.type<#outer x !Pair>) -> index + function.return %len : index + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @sink(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> index { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = array.len %[[VAL_0]], %[[VAL_2]] : <#[[$M]] x index> +// CHECK-NEXT: function.return %[[VAL_3]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @main(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> index { +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = function.call @sink(%[[VAL_4]], %[[VAL_5]]) : (!array.type<#[[$M]] x index>, !array.type<#[[$M]] x !felt.type>) -> index +// CHECK-NEXT: function.return %[[VAL_6]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } From 1401bf2e331f1a4066854222cf3cbfc2c51d7fc5 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 11:40:56 -0500 Subject: [PATCH 32/36] fix: Preserve member read offsets when splitting POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 22 ++++++++- ...ead_column_affine_offset_array_of_pod.llzk | 47 +++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 test/Transforms/PodToScalar/member_read_column_affine_offset_array_of_pod.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index ed619030d..37cf2e9fb 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -2276,12 +2276,32 @@ class SplitPodArrayInMemberReadOp : public OpConversionPattern { SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + SmallVector mapOperands; + std::optional numDimsPerMap; + auto mapOperandsOld = adaptor.getMapOperands(); + if (!mapOperandsOld.empty()) { + assert( + mapOperandsOld.size() == 1 && + "member.readm should have at most one affine-map operand group" + ); + mapOperands = flattenConvertedValues(mapOperandsOld.front()); + + ArrayRef numDimsPerMapOld = op.getNumDimsPerMap(); + if (!numDimsPerMapOld.empty()) { + assert( + numDimsPerMapOld.size() == 1 && + "member.readm should have one numDims entry per affine-map group" + ); + numDimsPerMap = numDimsPerMapOld.front(); + } + } SmallVector replacements; replacements.reserve(splitIds.size()); for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { const MemberInfo &newMember = idToMember.at(id); replacements.push_back(rewriter.create( - op.getLoc(), splitType, getSingleConvertedValue(adaptor.getComponent()), newMember.first + op.getLoc(), splitType, getSingleConvertedValue(adaptor.getComponent()), newMember.first, + op.getTableOffset().value_or(nullptr), mapOperands, numDimsPerMap )); } rewriter.replaceOpWithMultiple(op, {ValueRange(replacements)}); diff --git a/test/Transforms/PodToScalar/member_read_column_affine_offset_array_of_pod.llzk b/test/Transforms/PodToScalar/member_read_column_affine_offset_array_of_pod.llzk new file mode 100644 index 000000000..214182f72 --- /dev/null +++ b/test/Transforms/PodToScalar/member_read_column_affine_offset_array_of_pod.llzk @@ -0,0 +1,47 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +#map = affine_map<()[s0] -> (s0 - 1)> +module attributes {llzk.lang} { + struct.def @S { + struct.member @m : !array.type<2 x !pod.type<[@x: !felt.type, @y: !felt.type]>> {column} + + function.def @compute(%idx: index) -> !struct.type<@S> { + %self = struct.new : !struct.type<@S> + function.return %self : !struct.type<@S> + } + + function.def @constrain(%self: !struct.type<@S>, %idx: index) { + %arr = struct.readm %self[@m] {()[%idx]} : + !struct.type<@S>, !array.type<2 x !pod.type<[@x: !felt.type, @y: !felt.type]>> + {tableOffset = #map} + %c0 = arith.constant 0 : index + %elt = array.read %arr[%c0] : + !array.type<2 x !pod.type<[@x: !felt.type, @y: !felt.type]>>, + !pod.type<[@x: !felt.type, @y: !felt.type]> + %x = pod.read %elt[@x] : !pod.type<[@x: !felt.type, @y: !felt.type]>, !felt.type + %y = pod.read %elt[@y] : !pod.type<[@x: !felt.type, @y: !felt.type]>, !felt.type + constrain.eq %x, %y : !felt.type + function.return + } + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 - 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @S { +// CHECK-NEXT: struct.member @m_x : !array.type<2 x !felt.type> {column} +// CHECK-NEXT: struct.member @m_y : !array.type<2 x !felt.type> {column} +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: index) -> !struct.type<@S> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@S> +// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@S> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@S>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: index) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@m_x] {(){{\[}}%[[VAL_3]]]} : <@S>, !array.type<2 x !felt.type> {tableOffset = #[[$M]]} +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@m_y] {(){{\[}}%[[VAL_3]]]} : <@S>, !array.type<2 x !felt.type> {tableOffset = #[[$M]]} +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_4]]{{\[}}%[[VAL_6]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_5]]{{\[}}%[[VAL_6]]] : <2 x !felt.type>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_7]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } From 74d0b45bfc058ff50e44632e3bc3983e31ad2bcb Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 15:54:52 -0500 Subject: [PATCH 33/36] fix: Preserve the source shape for affine empty-POD arrays --- .../POD/Transforms/PodToScalarPass.cpp | 315 +++++++++++++++--- test/Transforms/PodToScalar/array_length.llzk | 45 ++- 2 files changed, 298 insertions(+), 62 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 37cf2e9fb..42ab602bb 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -383,6 +383,63 @@ static size_t splitPodArrayTypeTo( return 1; } +/// Return the index-array carrier type used to preserve the shape of a zero-leaf array-of-POD. +static ArrayType getZeroLeafPodArrayShapeCarrierType(ArrayType arrTy) { + return arrTy.cloneWith(IndexType::get(arrTy.getContext())); +} + +/// Return `true` iff splitting `arrTy` produces no concrete POD leaf arrays. +static bool hasZeroLeafPodArraySplit(ArrayType arrTy) { + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes); + return splitTypes.empty(); +} + +/// Convert one type using the step-2 array-of-POD lowering convention. +/// +/// Most arrays-of-POD expand to one parallel array per POD leaf. When the element POD has no +/// leaves, keep a single index-array carrier so later rewrites can still preserve shape and affine +/// instantiation information. +static size_t convertPodArrayTypeTo(Type t, SmallVectorImpl &collect) { + if (ArrayType arrTy = splittablePodArray(t)) { + size_t oldSize = collect.size(); + splitPodArrayTypeTo(arrTy, collect); + if (collect.size() == oldSize) { + collect.push_back(getZeroLeafPodArrayShapeCarrierType(arrTy)); + } + return collect.size() - oldSize; + } + + collect.push_back(t); + return 1; +} + +/// For each Type in the given input collection, call `convertPodArrayTypeTo(Type,...)`. +template +inline void convertPodArrayTypesTo( + TypeCollection types, SmallVectorImpl &collect, + SmallVector *originalIdxToSize = nullptr +) { + if (originalIdxToSize) { + originalIdxToSize->reserve(types.size()); + } + for (Type t : types) { + size_t count = convertPodArrayTypeTo(t, collect); + if (originalIdxToSize) { + originalIdxToSize->push_back(count); + } + } +} + +/// Return the step-2 converted types for the given collection. +template +static SmallVector +convertPodArrayTypes(TypeCollection types, SmallVector *originalIdxToSize = nullptr) { + SmallVector collect; + convertPodArrayTypesTo(types, collect, originalIdxToSize); + return collect; +} + /// For each Type in the given input collection, call `splitPodArrayTypeTo(Type,...)`. template inline void splitPodArrayTypeTo( @@ -463,6 +520,50 @@ inline static Value getSingleConvertedValue(ValueRange values) { return values.front(); } +/// Store the affine-map operand groups needed to rebuild one concrete array instantiation. +/// +/// The layout mirrors `array.new`: `mapOperandStorage` keeps each instantiation group separately, +/// and `numDimsPerMap` records how many values in each group are dimensional arguments. +struct ArrayInstantiationInfo { + SmallVector> mapOperandStorage; + SmallVector numDimsPerMap; +}; + +/// Try to recover affine-map instantiation operands from a concrete array-producing value. +/// +/// This peels compatibility casts, follows simple `pod.read` to dominating `pod.write` +/// forwarding, and succeeds only when the value ultimately traces back to a concrete +/// `array.new` carrying the instantiation groups. +static std::optional tryGetArrayInstantiationInfo(Value value) { + while (auto cast = value.getDefiningOp()) { + value = cast.getInput(); + } + + if (ReadPodOp read = value.getDefiningOp()) { + if (WritePodOp write = findNearestForwardableWriteInBlock(read)) { + return tryGetArrayInstantiationInfo(write.getValue()); + } + return std::nullopt; + } + + auto create = value.getDefiningOp(); + if (!create) { + return std::nullopt; + } + + ArrayInstantiationInfo info; + info.mapOperandStorage.reserve(create.getMapOperands().size()); + for (OperandRange group : create.getMapOperands()) { + info.mapOperandStorage.emplace_back(group.begin(), group.end()); + } + + if (DenseI32ArrayAttr numDimsPerMap = create.getNumDimsPerMapAttr()) { + llvm::append_range(info.numDimsPerMap, numDimsPerMap.asArrayRef()); + } + + return info; +} + /// Materialize a scalar array value that preserves the shape of `originalArrTy`. /// /// This is used as a shape-only carrier for `array.len` when an array-of-POD splits to @@ -470,7 +571,7 @@ inline static Value getSingleConvertedValue(ValueRange values) { static Value materializeArrayLengthCarrier( Value originalArrRef, ArrayType originalArrTy, Location loc, ConversionPatternRewriter &rewriter ) { - ArrayType carrierTy = originalArrTy.cloneWith(IndexType::get(rewriter.getContext())); + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(originalArrTy); if (auto create = originalArrRef.getDefiningOp()) { if (create.getMapOperands().empty()) { @@ -487,6 +588,22 @@ static Value materializeArrayLengthCarrier( ); } + if (std::optional instantiation = + tryGetArrayInstantiationInfo(originalArrRef)) { + if (instantiation->mapOperandStorage.empty()) { + return rewriter.create(loc, carrierTy); + } + + SmallVector mapOperands; + mapOperands.reserve(instantiation->mapOperandStorage.size()); + for (const SmallVector &group : instantiation->mapOperandStorage) { + mapOperands.push_back(group); + } + return rewriter.create( + loc, carrierTy, mapOperands, ArrayRef(instantiation->numDimsPerMap) + ); + } + bool hasAffineDims = llvm::any_of(originalArrTy.getDimensionSizes(), [](Attribute dimSize) { return llvm::isa(dimSize); }); @@ -554,15 +671,6 @@ inline static Value createWritableArrayValue(OpBuilder &bldr, Location loc, Arra } } -/// Store the affine-map operand groups needed to rebuild one concrete array instantiation. -/// -/// The layout mirrors `array.new`: `mapOperandStorage` keeps each instantiation group separately, -/// and `numDimsPerMap` records how many values in each group are dimensional arguments. -struct ArrayInstantiationInfo { - SmallVector> mapOperandStorage; - SmallVector numDimsPerMap; -}; - /// Return `true` iff two recovered array instantiations can be rebuilt identically. static bool equivalentArrayInstantiationInfo( const ArrayInstantiationInfo &lhs, const ArrayInstantiationInfo &rhs @@ -581,41 +689,6 @@ static bool equivalentArrayInstantiationInfo( return true; } -/// Try to recover affine-map instantiation operands from a concrete array-producing value. -/// -/// This peels compatibility casts, follows simple `pod.read` to dominating `pod.write` -/// forwarding, and succeeds only when the value ultimately traces back to a concrete -/// `array.new` carrying the instantiation groups. -static std::optional tryGetArrayInstantiationInfo(Value value) { - while (auto cast = value.getDefiningOp()) { - value = cast.getInput(); - } - - if (ReadPodOp read = value.getDefiningOp()) { - if (WritePodOp write = findNearestForwardableWriteInBlock(read)) { - return tryGetArrayInstantiationInfo(write.getValue()); - } - return std::nullopt; - } - - auto create = value.getDefiningOp(); - if (!create) { - return std::nullopt; - } - - ArrayInstantiationInfo info; - info.mapOperandStorage.reserve(create.getMapOperands().size()); - for (OperandRange group : create.getMapOperands()) { - info.mapOperandStorage.emplace_back(group.begin(), group.end()); - } - - if (DenseI32ArrayAttr numDimsPerMap = create.getNumDimsPerMapAttr()) { - llvm::append_range(info.numDimsPerMap, numDimsPerMap.asArrayRef()); - } - - return info; -} - /// Describe whether a set of leaf arrays shares one recoverable instantiation. enum class CommonArrayInstantiationStatus : std::uint8_t { unavailable, @@ -1268,6 +1341,12 @@ class SplitPodArrayInMemberDefOp : public OpConversionPattern { SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(arrTy); + rewriter.modifyOpInPlace(op, [&]() { op.setType(carrierTy); }); + localRepMapRef[RecordChain()] = std::make_pair(op.getSymNameAttr(), carrierTy); + return; + } SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct); for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { @@ -1316,7 +1395,7 @@ class PodArrayTypeConverter : public TypeConverter { if (!splittablePodArray(arrTy)) { return std::nullopt; } - splitPodArrayTypeTo(arrTy, results); + convertPodArrayTypeTo(arrTy, results); return success(); } ); @@ -1345,6 +1424,12 @@ class SplitPodArrayNonDetOp : public OpConversionPattern { void rewrite(NonDetOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { SmallVector splitTypes; splitPodArrayTypeTo(op.getType(), splitTypes); + if (splitTypes.empty()) { + rewriter.replaceOpWithNewOp( + op, getZeroLeafPodArrayShapeCarrierType(llvm::cast(op.getType())) + ); + return; + } SmallVector replacements; replacements.reserve(splitTypes.size()); for (Type splitType : splitTypes) { @@ -1380,6 +1465,28 @@ class SplitPodArrayCreateArrayOp : public OpConversionPattern { SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(arrTy); + if (adaptor.getMapOperands().empty()) { + rewriter.replaceOpWithNewOp(op, carrierTy); + return success(); + } + + SmallVector> mapOperandStorage; + SmallVector mapOperands; + mapOperandStorage.reserve(adaptor.getMapOperands().size()); + mapOperands.reserve(adaptor.getMapOperands().size()); + for (ArrayRef mapOperandGroup : adaptor.getMapOperands()) { + mapOperandStorage.push_back(flattenConvertedValues(mapOperandGroup)); + } + for (const SmallVector &values : mapOperandStorage) { + mapOperands.push_back(values); + } + rewriter.replaceOpWithNewOp( + op, carrierTy, mapOperands, op.getNumDimsPerMapAttr() + ); + return success(); + } SmallVector replacements; replacements.reserve(splitTypes.size()); @@ -1508,6 +1615,10 @@ class SplitPodArrayReadArrayOp : public OpConversionPattern { SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + rewriter.replaceOpWithNewOp(op, podTy); + return success(); + } SmallVector indices = flattenConvertedValues(adaptor.getIndices()); NewPodOp pod = rewriter.create(op.getLoc(), podTy); @@ -1548,6 +1659,10 @@ class SplitPodArrayWriteArrayOp : public OpConversionPattern { SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + rewriter.eraseOp(op); + return success(); + } SmallVector indices = flattenConvertedValues(adaptor.getIndices()); Value podValue = getSingleConvertedValue(adaptor.getRvalue()); @@ -1597,9 +1712,9 @@ class SplitPodArrayInFuncDefOp : public OpConversionPattern { } SmallVector originalInputIdxToSize, originalResultIdxToSize; - SmallVector newInputs = splitPodArrayType(oldTy.getInputs(), &originalInputIdxToSize); + SmallVector newInputs = convertPodArrayTypes(oldTy.getInputs(), &originalInputIdxToSize); SmallVector newResultsWithSizeInfo = - splitPodArrayType(oldTy.getResults(), &originalResultIdxToSize); + convertPodArrayTypes(oldTy.getResults(), &originalResultIdxToSize); assert( newResultsWithSizeInfo == newResults && "expected array-of-pod type conversion to match function result attr replication" @@ -1654,6 +1769,17 @@ static void collectSplitPodArrayOperandValues( SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + ArrayType carrierTy = getZeroLeafPodArrayShapeCarrierType(arrTy); + if (!convertedValues.empty()) { + newOperands.push_back(castValueToTypeIfNeeded( + rewriter, loc, getSingleConvertedValue(convertedValues), carrierTy + )); + return; + } + newOperands.push_back(materializeArrayLengthCarrier(originalOperand, arrTy, loc, rewriter)); + return; + } auto isDirectAggregateToSplitCast = [&convertedValues, &splitTypes]() { if (convertedValues.empty()) { @@ -1728,6 +1854,20 @@ class SplitPodArrayInUnifiableCastOp : public OpConversionPattern { auto newResultIt = newCall.getResults().begin(); for (Type oldResultType : op.getResultTypes()) { SmallVector convertedTypes; - (void)splitPodArrayTypeTo(oldResultType, convertedTypes); + (void)convertPodArrayTypeTo(oldResultType, convertedTypes); SmallVector replacementsForResult; replacementsForResult.reserve(convertedTypes.size()); for (size_t i = 0; i < convertedTypes.size(); ++i) { @@ -1860,6 +2000,28 @@ class SplitPodArrayInEmitEqualityOp : public OpConversionPattern( + op.getLoc(), rewriter.getIndexAttr(llzk::checkedCast(dim)) + ); + Value lhsLen = rewriter.create(op.getLoc(), lhsCarrier, dimVal); + Value rhsLen = rewriter.create(op.getLoc(), rhsCarrier, dimVal); + rewriter.create(op.getLoc(), lhsLen, rhsLen); + } + rewriter.eraseOp(op); + return success(); + } + if (adaptor.getLhs().size() != adaptor.getRhs().size()) { return rewriter.notifyMatchFailure( op, "expected array-of-pod equality operands to expand to the same number of leaves" @@ -1905,6 +2067,11 @@ class SplitPodArrayInEmitContainmentOp : public OpConversionPattern podLeaves; processInputOperand(loc, getSingleConvertedValue(convertedValues), podLeaves, rewriter); @@ -1933,7 +2100,8 @@ class SplitPodArrayInEmitContainmentOp : public OpConversionPattern= rhsRank && "constrain.in verifier should reject higher-rank rhs arrays"); size_t selectedDims = lhsRank - rhsRank; - SmallVector lhsLeaves(adaptor.getLhs().begin(), adaptor.getLhs().end()); + SmallVector lhsLeaves = + collectContainmentLeaves(loc, op.getLhs(), adaptor.getLhs(), rewriter); SmallVector rhsLeaves = collectContainmentLeaves(loc, op.getRhs(), adaptor.getRhs(), rewriter); if (lhsLeaves.size() != rhsLeaves.size()) { @@ -1942,9 +2110,11 @@ class SplitPodArrayInEmitContainmentOp : public OpConversionPattern(loc, rewriter.getIndexAttr(0)); Value trueVal = rewriter.create( loc, IntegerAttr::get(IntegerType::get(rewriter.getContext(), 1), 1) @@ -2037,6 +2207,9 @@ static Value rebuildSplitPodArrayQuantifierIterValue( SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(sortType, splitTypes, &splitIds); + if (splitTypes.empty()) { + return bldr.create(loc, llvm::cast(iterType)); + } assert( convertedSort.size() == splitIds.size() && "converted quantifier sort must provide one value per POD-array leaf" @@ -2155,6 +2328,14 @@ class SplitPodArrayExtractArrayOp : public OpConversionPattern { SmallVector splitResultTypes; splitPodArrayTypeTo(op.getResult().getType(), splitResultTypes); + if (splitResultTypes.empty()) { + ArrayType resultTy = llvm::cast(op.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, getZeroLeafPodArrayShapeCarrierType(resultTy), + getSingleConvertedValue(adaptor.getArrRef()), flattenConvertedValues(adaptor.getIndices()) + ); + return success(); + } SmallVector indices = flattenConvertedValues(adaptor.getIndices()); SmallVector replacements; @@ -2186,6 +2367,15 @@ class SplitPodArrayInsertArrayOp : public OpConversionPattern { return failure(); } + if (hasZeroLeafPodArraySplit(llvm::cast(op.getRvalue().getType()))) { + rewriter.create( + op.getLoc(), getSingleConvertedValue(adaptor.getArrRef()), + flattenConvertedValues(adaptor.getIndices()), getSingleConvertedValue(adaptor.getRvalue()) + ); + rewriter.eraseOp(op); + return success(); + } + SmallVector indices = flattenConvertedValues(adaptor.getIndices()); for (auto [splitArrRange, splitRvalueRange] : llvm::zip_equal(adaptor.getArrRef(), adaptor.getRvalue())) { @@ -2231,6 +2421,15 @@ class SplitPodArrayInMemberWriteOp : public OpConversionPattern { SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + if (splitTypes.empty()) { + const MemberInfo &carrierMember = idToMember.at(RecordChain()); + rewriter.create( + op.getLoc(), getSingleConvertedValue(adaptor.getComponent()), + FlatSymbolRefAttr::get(carrierMember.first), getSingleConvertedValue(adaptor.getVal()) + ); + rewriter.eraseOp(op); + return success(); + } for (auto [id, splitValRange] : llvm::zip_equal(splitIds, adaptor.getVal())) { const MemberInfo &newMember = idToMember.at(id); @@ -2275,7 +2474,6 @@ class SplitPodArrayInMemberReadOp : public OpConversionPattern { SmallVector splitIds; SmallVector splitTypes; splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); - SmallVector mapOperands; std::optional numDimsPerMap; auto mapOperandsOld = adaptor.getMapOperands(); @@ -2295,6 +2493,15 @@ class SplitPodArrayInMemberReadOp : public OpConversionPattern { numDimsPerMap = numDimsPerMapOld.front(); } } + if (splitTypes.empty()) { + const MemberInfo &carrierMember = idToMember.at(RecordChain()); + Value carrierRead = rewriter.create( + op.getLoc(), carrierMember.second, getSingleConvertedValue(adaptor.getComponent()), + carrierMember.first, op.getTableOffset().value_or(nullptr), mapOperands, numDimsPerMap + ); + rewriter.replaceOpWithMultiple(op, {ValueRange {carrierRead}}); + return success(); + } SmallVector replacements; replacements.reserve(splitIds.size()); for (auto [id, splitType] : llvm::zip_equal(splitIds, splitTypes)) { diff --git a/test/Transforms/PodToScalar/array_length.llzk b/test/Transforms/PodToScalar/array_length.llzk index 2ea7734d1..0257ee0c0 100644 --- a/test/Transforms/PodToScalar/array_length.llzk +++ b/test/Transforms/PodToScalar/array_length.llzk @@ -57,10 +57,11 @@ module attributes {llzk.lang} { function.return %len : index } } +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> // CHECK-LABEL: module attributes {llzk.lang} { // CHECK-NEXT: function.def @len_empty_leaf_array(%[[N:[0-9a-zA-Z_\.]+]]: index, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { -// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new{()[%[[N]]]} : <#map x index> -// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#map x index> +// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new{()[%[[N]]]} : <#[[$MAP]] x index> +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#[[$MAP]] x index> // CHECK-NEXT: function.return %[[LEN]] : index // CHECK-NEXT: } // CHECK-NEXT: } @@ -73,8 +74,7 @@ module attributes {llzk.lang} { } } // CHECK-LABEL: module attributes {llzk.lang} { -// CHECK-NEXT: function.def @len_empty_leaf_static_arg(%[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { -// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = array.new : <4 x index> +// CHECK-NEXT: function.def @len_empty_leaf_static_arg(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<4 x index>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { // CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <4 x index> // CHECK-NEXT: function.return %[[LEN]] : index // CHECK-NEXT: } @@ -90,11 +90,40 @@ module attributes {llzk.lang} { function.return %len : index } } -// CHECK: #map = affine_map<()[s0] -> (s0)> +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> // CHECK-LABEL: module attributes {llzk.lang} { -// CHECK-NEXT: function.def @len_empty_leaf_affine_arg(%[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { -// CHECK-NEXT: %[[ARR:[0-9a-zA-Z_\.]+]] = llzk.nondet : !array.type<#map x index> -// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#map x index> +// CHECK-NEXT: function.def @len_empty_leaf_affine_arg(%[[ARR:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, %[[DIM:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN:[0-9a-zA-Z_\.]+]] = array.len %[[ARR]], %[[DIM]] : <#[[$MAP]] x index> // CHECK-NEXT: function.return %[[LEN]] : index // CHECK-NEXT: } // CHECK-NEXT: } +// ----- + +#map_call = affine_map<()[s0] -> (s0)> +module attributes {llzk.lang} { + function.def @len_empty_leaf_affine_sink( + %arr: !array.type<#map_call x !pod.type<[]>>, %dim: index + ) -> index { + %len = array.len %arr, %dim : !array.type<#map_call x !pod.type<[]>> + function.return %len : index + } + + function.def @len_empty_leaf_affine_call( + %arr: !array.type<#map_call x !pod.type<[]>>, %dim: index + ) -> index { + %len = function.call @len_empty_leaf_affine_sink(%arr, %dim) + : (!array.type<#map_call x !pod.type<[]>>, index) -> index + function.return %len : index + } +} +// CHECK: #[[$MAP:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @len_empty_leaf_affine_sink(%[[ARR0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, %[[DIM0:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN0:[0-9a-zA-Z_\.]+]] = array.len %[[ARR0]], %[[DIM0]] : <#[[$MAP]] x index> +// CHECK-NEXT: function.return %[[LEN0]] : index +// CHECK-NEXT: } +// CHECK-NEXT: function.def @len_empty_leaf_affine_call(%[[ARR1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$MAP]] x index>, %[[DIM1:[0-9a-zA-Z_\.]+]]: index) -> index { +// CHECK-NEXT: %[[LEN1:[0-9a-zA-Z_\.]+]] = function.call @len_empty_leaf_affine_sink(%[[ARR1]], %[[DIM1]]) : (!array.type<#[[$MAP]] x index>, index) -> index +// CHECK-NEXT: function.return %[[LEN1]] : index +// CHECK-NEXT: } +// CHECK-NEXT: } From bd186a3cd4fc38b80ac981f50be3c1a28eab2154 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 16:26:03 -0500 Subject: [PATCH 34/36] fix: Use a writable affine-aware array for static POD-array leaves --- .../POD/Transforms/PodToScalarPass.cpp | 40 +++++++++---------- .../array_new_affine_leaf_array.llzk | 28 +++++++++++++ 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 42ab602bb..57e856c3b 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -904,32 +904,30 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef(getFlattenedTypeAlongPath(valueType, recordChain)); + SmallVector splitIds; + SmallVector splitTypes; + splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); + auto *splitIt = llvm::find(splitIds, RecordChain(recordChain)); + assert(splitIt != splitIds.end() && "record path must name a flattened POD array leaf"); + size_t splitIdx = std::distance(splitIds.begin(), splitIt); + + SmallVector leafArrays; + if (tryCollectDirectSplitPodArrayLeafValues(value, arrTy, splitTypes, leafArrays)) { + return leafArrays[splitIdx]; + } + + Value strippedValue = peelUnifiableCasts(value); + if (strippedValue.getDefiningOp()) { + auto splitLeafReads = + bldr.create(loc, TypeRange(splitTypes), strippedValue); + return splitLeafReads.getResult(splitIdx); + } if (isFreshUnwrittenPodArrayRead(value)) { return createWritableArrayValue(bldr, loc, splitArrTy); } if (!arrTy.hasStaticShape()) { - SmallVector splitIds; - SmallVector splitTypes; - splitPodArrayTypeTo(arrTy, splitTypes, &splitIds); - - SmallVector leafArrays; - if (tryCollectDirectSplitPodArrayLeafValues(value, arrTy, splitTypes, leafArrays)) { - auto *it = llvm::find(splitIds, RecordChain(recordChain)); - assert(it != splitIds.end() && "record path must name a flattened POD array leaf"); - return leafArrays[std::distance(splitIds.begin(), it)]; - } - - Value strippedValue = peelUnifiableCasts(value); - if (strippedValue.getDefiningOp()) { - auto splitLeafReads = - bldr.create(loc, TypeRange(splitTypes), strippedValue); - auto *it = llvm::find(splitIds, RecordChain(recordChain)); - assert(it != splitIds.end() && "record path must name a flattened POD array leaf"); - return splitLeafReads.getResult(std::distance(splitIds.begin(), it)); - } - llvm_unreachable( "non-static nested array-of-POD scalarization requires split-array backing or an " "uninitialized pod field" @@ -939,7 +937,7 @@ genReadAlongPath(OpBuilder &bldr, Location loc, Value value, ArrayRef(loc, splitArrTy); + Value splitArray = createWritableArrayValue(bldr, loc, splitArrTy); for (ArrayAttr index : *subIndices) { Value element = genArrayRead(bldr, loc, value, index); Value leafValue = genReadAlongPath(bldr, loc, element, recordChain); diff --git a/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk b/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk index 54873c893..9297ccf06 100644 --- a/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk +++ b/test/Transforms/PodToScalar/array_new_affine_leaf_array.llzk @@ -55,3 +55,31 @@ module attributes {llzk.lang} { // CHECK-NEXT: function.return %[[VAL_3]] : !array.type<2,#[[$M]] x index> // CHECK-NEXT: } // CHECK-NEXT: } +// ----- + +#inner = affine_map<()[s0] -> (s0 + 1)> +!Leaf = !array.type<#inner x index> +!Elem = !pod.type<[@vals: !Leaf]> +!ElemArray = !array.type<2 x !Elem> +!Outer = !pod.type<[@items: !ElemArray]> +module attributes {llzk.lang} { + function.def @sink(%arr: !ElemArray) { + function.return + } + + function.def @source(%p: !Outer) { + %items = pod.read %p[@items] : !Outer, !ElemArray + function.call @sink(%items) : (!ElemArray) -> () + function.return + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @sink(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<2,#[[$M]] x index>) { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: function.def @source(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2,#[[$M]] x index>) { +// CHECK-NEXT: function.call @sink(%[[VAL_1]]) : (!array.type<2,#[[$M]] x index>) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } From fe5f9a555e182186ecb6fd4a9a49087674cd9f6c Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 18:00:13 -0500 Subject: [PATCH 35/36] fix: Rewrite casts whose input is an array-of-POD --- .../POD/Transforms/PodToScalarPass.cpp | 61 +++++++++++++++++-- .../unifiable_cast_array_of_pod.llzk | 24 ++++++++ 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index 57e856c3b..a2d8b0dd9 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -624,6 +624,15 @@ static SmallVector flattenConvertedValues(RangeOfRanges ranges) { return values; } +/// Return `true` iff the inputs are the same size and each value type in `values` unifies +/// with the corresponding `types` entry. +template +inline static bool allValueTypesUnifyWithTypes(const ValueRangeLike &values, ArrayRef types) { + return llvm::all_of_zip(values, types, [](auto value, Type type) { + return typesUnify(value.getType(), type); + }); +} + /// Replace any AffineMap-backed array dimensions nested within `type` with wildcard `?` dims. /// /// This preserves the overall array nesting while erasing only the affine-map dimensions that @@ -1407,6 +1416,7 @@ class PodArrayTypeConverter : public TypeConverter { }; addTargetMaterialization(materializeCast); addArgumentMaterialization(materializeCast); + addSourceMaterialization(materializeCast); } }; @@ -1798,10 +1808,7 @@ static void collectSplitPodArrayOperandValues( }); }; - if (!isDirectAggregateToSplitCast() && convertedValues.size() == splitTypes.size() && - llvm::all_of(llvm::zip_equal(convertedValues, splitTypes), [](auto pair) { - return typesUnify(std::get<0>(pair).getType(), std::get<1>(pair)); - })) { + if (!isDirectAggregateToSplitCast() && allValueTypesUnifyWithTypes(convertedValues, splitTypes)) { llvm::append_range(newOperands, convertedValues); return; } @@ -1817,7 +1824,9 @@ class SplitPodArrayInUnifiableCastOp : public OpConversionPattern::OpConversionPattern; - static bool legal(UnifiableCastOp op) { return !splittablePodArray(op.getType()); } + static bool legal(UnifiableCastOp op) { + return !splittablePodArray(op.getType()) && !splittablePodArray(op.getInput().getType()); + } LogicalResult matchAndRewrite( UnifiableCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter @@ -1827,6 +1836,47 @@ class SplitPodArrayInUnifiableCastOp : public OpConversionPattern inputSplitTypes; + splitPodArrayTypeTo(inputArrTy, inputSplitTypes); + + SmallVector splitInputs; + collectSplitPodArrayOperandValues( + op.getLoc(), op.getInput(), adaptor.getInput(), splitInputs, rewriter + ); + if (inputSplitTypes.empty()) { + if (splitInputs.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected one shape carrier for zero-leaf array-of-pod cast input" + ); + } + } else if (splitInputs.size() != inputSplitTypes.size()) { + return rewriter.notifyMatchFailure( + op, "failed to collect one split input per array-of-pod cast leaf" + ); + } + + // `poly.unifiable_cast` to a non-array target cannot preserve all split leaf values in one + // SSA value without reintroducing aggregate array-of-POD materialization. + auto *it = llvm::find_if(splitInputs, [&op](Value v) { + return typesUnify(v.getType(), op.getType()); + }); + if (it == splitInputs.end()) { + return rewriter.notifyMatchFailure( + op, "failed to find split array leaf type compatible with cast target" + ); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), castValueToTypeIfNeeded(rewriter, op.getLoc(), *it, op.getType()) + ); + return success(); + } + if (!inputArrTy) { return rewriter.notifyMatchFailure( op, "expected array-of-pod cast input when rewriting array-of-pod cast result" @@ -1837,7 +1887,6 @@ class SplitPodArrayInUnifiableCastOp : public OpConversionPattern inputSplitTypes; splitPodArrayTypeTo(inputArrTy, inputSplitTypes, &inputSplitIds); - ArrayType resultArrTy = llvm::cast(op.getType()); SmallVector resultSplitIds; SmallVector resultSplitTypes; splitPodArrayTypeTo(resultArrTy, resultSplitTypes, &resultSplitIds); diff --git a/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk b/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk index 7f6cb3b10..21689a0b9 100644 --- a/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk +++ b/test/Transforms/PodToScalar/unifiable_cast_array_of_pod.llzk @@ -29,3 +29,27 @@ module attributes {llzk.lang} { // CHECK-NEXT: function.return %[[CAST_X]], %[[CAST_Y]], %[[CALL_LEN]] : !array.type, !array.type, index // CHECK-NEXT: } // CHECK-NEXT: } +// ----- + +#map = affine_map<()[s0] -> (s0)> +!Pair = !pod.type<[@x: index, @y: !felt.type]> +module attributes {llzk.lang} { + poly.template @CastToTypeVar { + poly.param @T_return : !poly.tvar<@T_return> + function.def @main(%arr: !array.type<#map x !Pair>) -> !poly.tvar<@T_return> { + %cast = poly.unifiable_cast %arr : (!array.type<#map x !Pair>) -> !poly.tvar<@T_return> + function.return %cast : !poly.tvar<@T_return> + } + } +} +// CHECK: #[[$M:[0-9a-zA-Z_\.]+]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: poly.template @CastToTypeVar { +// CHECK-NEXT: poly.param @T_return : !poly.tvar<@T_return> +// CHECK-NEXT: function.def @main(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x index>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<#[[$M]] x !felt.type>) -> !poly.tvar<@T_return> { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_0]] : (!array.type<#[[$M]] x index>) -> !poly.tvar<@T_return> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_2]] : (!poly.tvar<@T_return>) -> !poly.tvar<@T_return> +// CHECK-NEXT: function.return %[[VAL_3]] : !poly.tvar<@T_return> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } From ac74dfe60d53d527411e2b74ad510e75175518e7 Mon Sep 17 00:00:00 2001 From: Tim Hoffman Date: Thu, 2 Jul 2026 23:22:18 -0500 Subject: [PATCH 36/36] fix: Materialize virtual PODs where updated values dominate --- .../POD/Transforms/PodToScalarPass.cpp | 55 ++++++++++++++++++- .../whole_pod_use_after_virtual_write.llzk | 25 +++++++++ 2 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 test/Transforms/PodToScalar/whole_pod_use_after_virtual_write.llzk diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index a2d8b0dd9..980f69691 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -1097,6 +1097,30 @@ materializeVirtualPod(OpBuilder &bldr, NewPodOp pod, const VirtualPodLeafMap &le } } +/// Return the latest same-block operation that defines one of `leafValues`, or `pod` itself. +/// +/// Virtual PODs created from split block arguments can be updated later with scalar values defined +/// after the placeholder. Replaying the deferred writes immediately after the placeholder can then +/// violate SSA dominance. Materializing after the latest same-block leaf definition preserves the +/// original write ordering for these straight-line updates while keeping the fallback local. +static Operation * +findVirtualPodMaterializationAnchor(NewPodOp pod, const VirtualPodLeafMap &leafValues) { + Operation *anchor = pod.getOperation(); + Block *block = anchor->getBlock(); + + for (const auto &it : leafValues) { + Operation *defOp = it.second.getDefiningOp(); + if (!defOp || defOp->getBlock() != block) { + continue; + } + if (anchor->isBeforeInBlock(defOp)) { + anchor = defOp; + } + } + + return anchor; +} + /// Return `true` iff a read from a virtual POD can be resolved without materializing it. static bool canResolveVirtualPodRead(ReadPodOp op, const VirtualPodValueMap &virtualPods) { if (!lookupVirtualPodLeafMap(op.getPodRef(), virtualPods) || hasEarlierWriteInBlock(op) || @@ -3411,7 +3435,7 @@ step3(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM continue; } if (auto newPod = llvm::dyn_cast(podValue.getDefiningOp())) { - builder.setInsertionPointAfter(newPod); + builder.setInsertionPointAfter(findVirtualPodMaterializationAnchor(newPod, leafValues)); materializeVirtualPod(builder, newPod, leafValues); } } @@ -4187,6 +4211,33 @@ class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(constrain::EmitEqualityOp op, PatternRewriter &rewriter) const override { + PodType podTy = splittablePod(op.getLhs().getType()); + if (!podTy) { + return failure(); + } + + SmallVector recordChain; + forEachPodLeaf(podTy, recordChain, [&rewriter, &op](const RecordChain &id, Type) { + Value lhsLeaf = genReadAlongPath(rewriter, op.getLoc(), op.getLhs(), id); + Value rhsLeaf = genReadAlongPath(rewriter, op.getLoc(), op.getRhs(), id); + rewriter.create(op.getLoc(), lhsLeaf, rhsLeaf); + }); + rewriter.eraseOp(op); + return success(); + } +}; + /// Apply a greedy rewrite/fold pass over the module body using the provided patterns. static LogicalResult applyGreedily(ModuleOp modOp, RewritePatternSet &&patterns, bool *changed = nullptr) { @@ -4203,7 +4254,7 @@ static LogicalResult step4(ModuleOp modOp) { patterns.add< FoldReadAfterWriteInBlockPattern, ReplaceIfReadPattern, LiftPodWritesFromIfBlocksPattern, LiftPodAccessesFromForLoopPattern, LiftPodAccessesFromWhileLoopPattern, - FoldIfCarriedPodReadAfterWritePattern>(patterns.getContext()); + FoldIfCarriedPodReadAfterWritePattern, SplitPodInEmitEqualityPattern>(patterns.getContext()); LLVM_DEBUG(llvm::dbgs() << "Begin step 4: refactor pod ops within SCF regions\n";); return applyGreedily(modOp, std::move(patterns)); diff --git a/test/Transforms/PodToScalar/whole_pod_use_after_virtual_write.llzk b/test/Transforms/PodToScalar/whole_pod_use_after_virtual_write.llzk new file mode 100644 index 000000000..100584416 --- /dev/null +++ b/test/Transforms/PodToScalar/whole_pod_use_after_virtual_write.llzk @@ -0,0 +1,25 @@ +// RUN: llzk-opt -llzk-pod-to-scalar %s | FileCheck --enable-var-scope %s + +!Single = !pod.type<[@x: index]> +module attributes {llzk.lang} { + function.def @whole_use_after_computed_write( + %lhs: !Single, %rhs: !Single, %a: index, %b: index + ) attributes {function.allow_constraint} { + %sum = arith.addi %a, %b : index + pod.write %lhs[@x] = %sum : !Single, index + constrain.eq %lhs, %rhs : !Single + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @whole_use_after_computed_write( +// CHECK-SAME: %[[UNUSED:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[RHS:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[A:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[B:[0-9a-zA-Z_\.]+]]: index +// CHECK-SAME: ) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[SUM:[0-9a-zA-Z_\.]+]] = arith.addi %[[A]], %[[B]] : index +// CHECK-NEXT: constrain.eq %[[SUM]], %[[RHS]] : index, index +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: }