From 4e7dc629e4919d5c8ea4031df9bbf9190345fc53 Mon Sep 17 00:00:00 2001 From: Martun Karapetyan Date: Fri, 12 Jul 2024 18:16:39 +0400 Subject: [PATCH] Optimized gate argument code. (#10) Co-authored-by: Martun Karapetyan --- .../crypto3/zk/math/expression_evaluator.hpp | 8 +- .../arithmetization/plonk/assignment.hpp | 53 +++++ .../arithmetization/plonk/constraint.hpp | 111 +++++---- .../plonk/placeholder/gates_argument.hpp | 74 +++--- .../plonk/placeholder/lookup_argument.hpp | 68 +++--- libs/parallel-zk/test/math/expression.cpp | 216 +++++++++--------- 6 files changed, 292 insertions(+), 238 deletions(-) diff --git a/libs/parallel-zk/include/nil/crypto3/zk/math/expression_evaluator.hpp b/libs/parallel-zk/include/nil/crypto3/zk/math/expression_evaluator.hpp index 2093dbd1..7ade1077 100644 --- a/libs/parallel-zk/include/nil/crypto3/zk/math/expression_evaluator.hpp +++ b/libs/parallel-zk/include/nil/crypto3/zk/math/expression_evaluator.hpp @@ -95,7 +95,7 @@ namespace nil { */ expression_evaluator( const math::expression& expr, - std::function get_var_value) + std::function get_var_value) : expr(expr) , get_var_value(get_var_value) { } @@ -140,7 +140,7 @@ namespace nil { const math::expression& expr; // A function used to retrieve the value of a variable. - std::function get_var_value; + std::function get_var_value; }; @@ -207,7 +207,7 @@ namespace nil { */ cached_expression_evaluator( const math::expression& expr, - std::function get_var_value) + std::function get_var_value) : _expr(expr) , _get_var_value(get_var_value) { } @@ -304,7 +304,7 @@ namespace nil { const math::expression& _expr; // A function used to retrieve the value of a variable. - std::function _get_var_value; + std::function _get_var_value; // Shows how many times each subexpression appears. We count have the expression // itself as a key, but apparently it's waay too slow. Just map the hash->count, assume diff --git a/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/assignment.hpp b/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/assignment.hpp index 1053c88c..435686ae 100644 --- a/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/assignment.hpp +++ b/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/assignment.hpp @@ -32,6 +32,7 @@ #include #include #include +#include namespace nil { namespace blueprint { @@ -55,6 +56,7 @@ namespace nil { class plonk_private_table { public: using witnesses_container_type = std::vector; + using VariableType = plonk_variable; protected: @@ -82,6 +84,30 @@ namespace nil { return _witnesses[index].size(); } + const ColumnType& get_variable_value_without_rotation(const VariableType& var) const { + switch (var.type) { + case VariableType::column_type::witness: + return witness(var.index); + case VariableType::column_type::public_input: + return public_input(var.index); + case VariableType::column_type::constant: + return constant(var.index); + case VariableType::column_type::selector: + return selector(var.index); + default: + std::cerr << "Invalid column type" << std::endl; + abort(); + } + } + ColumnType get_variable_value(const VariableType& var, std::shared_ptr> domain) const { + if (var.rotation == 0) { + return get_variable_value_without_rotation(var); + } + return math::polynomial_shift( + this->get_variable_value_without_rotation(var), + var.rotation, domain->m); + } + const ColumnType& witness(std::uint32_t index) const { assert(index < _witnesses.size()); return _witnesses[index]; @@ -126,6 +152,7 @@ namespace nil { using public_input_container_type = std::vector; using constant_container_type = std::vector; using selector_container_type = std::vector; + using VariableType = plonk_variable; protected: @@ -286,6 +313,7 @@ namespace nil { using public_input_container_type = typename public_table_type::public_input_container_type; using constant_container_type = typename public_table_type::constant_container_type; using selector_container_type = typename public_table_type::selector_container_type; + using VariableType = plonk_variable; protected: // These are normally created by the assigner, or read from a file. @@ -309,6 +337,31 @@ namespace nil { , _public_table(public_inputs_amount, constants_amount, selectors_amount) { } + const ColumnType& get_variable_value_without_rotation(const VariableType& var) const { + switch (var.type) { + case VariableType::column_type::witness: + return witness(var.index); + case VariableType::column_type::public_input: + return public_input(var.index); + case VariableType::column_type::constant: + return constant(var.index); + case VariableType::column_type::selector: + return selector(var.index); + default: + std::cerr << "Invalid column type" << std::endl; + abort(); + } + } + + ColumnType get_variable_value(const VariableType& var, std::shared_ptr> domain) const { + if (var.rotation == 0) { + return get_variable_value_without_rotation(var); + } + return math::polynomial_shift( + this->get_variable_value_without_rotation(var), + var.rotation, domain->m); + } + const ColumnType& witness(std::uint32_t index) const { return _private_table.witness(index); } diff --git a/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/constraint.hpp b/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/constraint.hpp index ae8565c0..a8eb4523 100644 --- a/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/constraint.hpp +++ b/libs/parallel-zk/include/nil/crypto3/zk/snark/arithmetization/plonk/constraint.hpp @@ -88,7 +88,7 @@ namespace nil { const plonk_assignment_table &assignments) const { math::expression_evaluator evaluator( *this, - [&assignments, row_index](const VariableType &var) { + [&assignments, row_index](const VariableType &var) -> const typename VariableType::assignment_type& { std::size_t rows_amount = assignments.rows_amount(); switch (var.type) { case VariableType::column_type::witness: @@ -100,8 +100,8 @@ namespace nil { case VariableType::column_type::selector: return assignments.selector(var.index)[(rows_amount + row_index + var.rotation) % rows_amount]; default: - BOOST_ASSERT_MSG(false, "Invalid column type"); - return VariableType::assignment_type::zero(); + std::cerr << "Invalid column type" << std::endl; + abort(); } }); @@ -109,39 +109,35 @@ namespace nil { } math::polynomial - evaluate(const plonk_polynomial_table &assignments, - std::shared_ptr> - domain) const { - using polynomial_type = math::polynomial; - using polynomial_variable_type = plonk_variable; - math::expression_variable_type_converter converter; - - math::expression_evaluator evaluator( - converter.convert(*this), - [&domain, &assignments](const VariableType &var) { - polynomial_type assignment; - switch (var.type) { - case VariableType::column_type::witness: - assignment = assignments.witness(var.index); - break; - case VariableType::column_type::public_input: - assignment = assignments.public_input(var.index); - break; - case VariableType::column_type::constant: - assignment = assignments.constant(var.index); - break; - case VariableType::column_type::selector: - assignment = assignments.selector(var.index); - break; - default: - BOOST_ASSERT_MSG(false, "Invalid column type"); - } + evaluate(const plonk_polynomial_table &assignments, + std::shared_ptr> domain) const { - if (var.rotation != 0) { - assignment = - math::polynomial_shift(assignment, domain->get_domain_element(var.rotation)); + using polynomial_type = math::polynomial; + using polynomial_variable_type = plonk_variable; + + // Convert scalar values to polynomials inside the expression. + math::expression_variable_type_converter converter; + auto converted_expression = converter.convert(*this); + + // For each variable with a rotation pre-compute its value. + std::unordered_map rotated_variable_values; + + math::expression_for_each_variable_visitor visitor( + [&rotated_variable_values, &assignments, &domain](const polynomial_variable_type& var) { + if (var.rotation == 0) + return; + rotated_variable_values[var] = assignments.get_variable_value(var, domain); + }); + visitor.visit(converted_expression); + + math::expression_evaluator evaluator( + converted_expression, + [&domain, &assignments, &rotated_variable_values] + (const VariableType &var) -> const polynomial_type& { + if (var.rotation == 0) { + return assignments.get_variable_value_without_rotation(var, domain); } - return assignment; + return rotated_variable_values[var]; }); return evaluator.evaluate(); } @@ -152,35 +148,33 @@ namespace nil { using polynomial_dfs_type = math::polynomial_dfs; using polynomial_dfs_variable_type = plonk_variable; + // Convert scalar values to polynomials inside the expression. math::expression_variable_type_converter converter( [&assignments](const typename VariableType::assignment_type& coeff) { polynomial_dfs_type(0, assignments.rows_amount(), coeff); }); - math::expression_evaluator evaluator( - converter.convert(*this), - [&domain, &assignments](const polynomial_dfs_variable_type &var) { - polynomial_dfs_type assignment; - switch (var.type) { - case VariableType::column_type::witness: - assignment = assignments.witness(var.index); - break; - case VariableType::column_type::public_input: - assignment = assignments.public_input(var.index); - break; - case VariableType::column_type::constant: - assignment = assignments.constant(var.index); - break; - case VariableType::column_type::selector: - assignment = assignments.selector(var.index); - break; - default: - BOOST_ASSERT_MSG(false, "Invalid column type"); - } - if (var.rotation != 0) { - assignment = math::polynomial_shift(assignment, var.rotation, domain->m); + auto converted_expression = converter.convert(*this); + + // For each variable with a rotation pre-compute its value. + std::unordered_map rotated_variable_values; + + math::expression_for_each_variable_visitor visitor( + [&rotated_variable_values, &assignments, &domain](const polynomial_dfs_variable_type& var) { + if (var.rotation == 0) + return ; + rotated_variable_values[var] = assignments.get_variable_value(var, domain); + }); + visitor.visit(converted_expression); + + math::expression_evaluator evaluator( + converted_expression, + [&domain, &assignments, &rotated_variable_values] + (const polynomial_dfs_variable_type &var) -> const polynomial_dfs_type& { + if (var.rotation == 0) { + return assignments.get_variable_value_without_rotation(var, domain); } - return assignment; + return rotated_variable_values[var]; } ); @@ -189,9 +183,10 @@ namespace nil { typename VariableType::assignment_type evaluate(detail::plonk_evaluation_map &assignments) const { + math::expression_evaluator evaluator( *this, - [&assignments](const VariableType &var) { + [&assignments](const VariableType &var) -> const typename VariableType::assignment_type& { std::tuple key = std::make_tuple(var.index, var.rotation, var.type); diff --git a/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/gates_argument.hpp b/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/gates_argument.hpp index 05f00bf7..37ee0ec1 100644 --- a/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/gates_argument.hpp +++ b/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/gates_argument.hpp @@ -100,9 +100,10 @@ namespace nil { variable_counts[var]++; }); + visitor.visit(expr); + std::shared_ptr> extended_domain = math::make_evaluation_domain(extended_domain_size); - visitor.visit(expr); parallel_for(0, variables.size(), [&variables, &variable_values_out, &assignments, &domain, &extended_domain, extended_domain_size](std::size_t i) { @@ -110,29 +111,9 @@ namespace nil { // We may have variable values in required sizes in some cases. if (variable_values_out[var].size() == extended_domain_size) return; - polynomial_dfs_type assignment; - switch (var.type) { - case polynomial_dfs_variable_type::column_type::witness: - assignment = assignments.witness(var.index); - break; - case polynomial_dfs_variable_type::column_type::public_input: - assignment = assignments.public_input(var.index); - break; - case polynomial_dfs_variable_type::column_type::constant: - assignment = assignments.constant(var.index); - break; - case polynomial_dfs_variable_type::column_type::selector: - assignment = assignments.selector(var.index); - break; - default: - std::cerr << "Invalid column type"; - std::abort(); - break; - } - if (var.rotation != 0) { - assignment = math::polynomial_shift(assignment, var.rotation, domain->m); - } + polynomial_dfs_type assignment = assignments.get_variable_value(var, domain); + // In parallel version we always resize the assignment poly, it's better for parallelization. // if (count > 1) { assignment.resize(extended_domain_size, domain, extended_domain); @@ -191,7 +172,25 @@ namespace nil { for (const auto& gate: gates) { std::vector> gate_results(extended_domain_sizes.size()); - for (const auto& constraint : gate.constraints) { + // We will split gates into parts especially for zkEVM circuit, since there is only 1 large gate with + // 683 constraints. Will split it into 24 parts, ~32 constraints each. + // This will mean our code will multiply by selector 16 times, instead of just once. But this is + // much better that losing parallelization. We do not want to re-write the whole code to try parallelize + // each gate compatation separately. This will not harm circuits with smaller number of terms much. + std::vector> gate_parts(extended_domain_sizes.size()); + std::vector gate_parts_constaint_counts(extended_domain_sizes.size()); + + + // This parameter can be tuned based on the circuit and the number of cores of the server on which the proofs + // are generated. On the current zkEVM circuit this value is optimal based on experiments. + const std::size_t constraint_limit = 16; + + + auto selector = polynomial_dfs_variable_type( + gate.selector_index, 0, false, polynomial_dfs_variable_type::column_type::selector); + + for (std::size_t constraint_idx = 0; constraint_idx < gate.constraints.size(); ++constraint_idx) { + const auto& constraint = gate.constraints[constraint_idx]; auto next_term = converter.convert(constraint) * value_type_to_polynomial_dfs(theta_acc); theta_acc *= theta; @@ -201,19 +200,26 @@ namespace nil { // Whatever the degree of term is, add it to the maximal degree expression. if (degree_limits[i] >= constraint_degree || i == 0) { gate_results[i] += next_term; + gate_parts[i] += next_term; + gate_parts_constaint_counts[i]++; + + // If we already have constraint_limit constaints in the gate_parts[i], add it to the 'subexpressions'. + if (gate_parts_constaint_counts[i] == constraint_limit) { + subexpressions[i].push_back(gate_parts[i] * selector); + gate_parts[i] = math::expression(); + gate_parts_constaint_counts[i] = 0; + } break; } + } } - auto selector = polynomial_dfs_variable_type( - gate.selector_index, 0, false, polynomial_dfs_variable_type::column_type::selector); - for (size_t i = 0; i < extended_domain_sizes.size(); ++i) { - gate_results[i] *= selector; // Only in parallel version we store the subexpressions of each expression and ignore the cache. - expressions[i] += gate_results[i]; - subexpressions[i].push_back(gate_results[i]); + expressions[i] += gate_results[i] * selector; + if (gate_parts_constaint_counts[i] != 0) + subexpressions[i].push_back(gate_parts[i] * selector); } } @@ -230,10 +236,12 @@ namespace nil { std::vector subvalues(subexpressions[i].size()); parallel_for(0, subexpressions[i].size(), [&subexpressions, &variable_values, &extended_domain_sizes, &subvalues, i](std::size_t subexpression_index) { - // Only in parallel version we store the subexpressions of each expression and ignore the cache, not using "cached_expression_evaluator". + // Only in parallel version we store the subexpressions of each expression and ignore the cache, + // not using "cached_expression_evaluator". math::expression_evaluator evaluator( - subexpressions[i][subexpression_index], [&assignments=variable_values, domain_size=extended_domain_sizes[i]] - (const polynomial_dfs_variable_type &var) { + subexpressions[i][subexpression_index], + [&assignments=variable_values, domain_size=extended_domain_sizes[i]] + (const polynomial_dfs_variable_type &var) -> const polynomial_dfs_type& { return assignments[var]; }); subvalues[subexpression_index] = evaluator.evaluate(); diff --git a/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/lookup_argument.hpp b/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/lookup_argument.hpp index 01eb567d..dd97f891 100644 --- a/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/lookup_argument.hpp +++ b/libs/parallel-zk/include/nil/crypto3/zk/snark/systems/plonk/placeholder/lookup_argument.hpp @@ -481,44 +481,14 @@ namespace nil { std::unique_ptr>> prepare_lookup_input() { PROFILE_PLACEHOLDER_SCOPE("Lookup argument preparing lookup input"); - - // Copied from gate argument. - // TODO: remove code duplication. - - - auto get_var_value = [&domain=basic_domain, &assignments=plonk_columns] - (const DfsVariableType &var) { - polynomial_dfs_type assignment; - switch (var.type) { - case DfsVariableType::column_type::witness: - assignment = assignments.witness(var.index); - break; - case DfsVariableType::column_type::public_input: - assignment = assignments.public_input(var.index); - break; - case DfsVariableType::column_type::constant: - assignment = assignments.constant(var.index); - break; - case DfsVariableType::column_type::selector: - assignment = assignments.selector(var.index); - break; - default: - std::cerr << "Invalid column type"; - std::abort(); - break; - } - - if (var.rotation != 0) { - assignment = math::polynomial_shift(assignment, var.rotation, domain->m); - } - return assignment; - }; - + + using polynomial_dfs_type = math::polynomial_dfs; + using polynomial_dfs_variable_type = plonk_variable; // Prepare lookup input - auto lookup_input_ptr = std::make_unique>>(); + auto lookup_input_ptr = std::make_unique>(); for (const auto &gate : lookup_gates) { - math::polynomial_dfs lookup_selector = plonk_columns.selector(gate.tag_index); + polynomial_dfs_type lookup_selector = plonk_columns.selector(gate.tag_index); // Increase the size to fit the next table values. std::size_t lookup_inputs_used = lookup_input_ptr->size(); @@ -526,7 +496,7 @@ namespace nil { // Do NOT capture converter by reference. parallel_for(0, gate.constraints.size(), - [&lookup_input_ptr, this, &gate, &lookup_selector, lookup_inputs_used, get_var_value](std::size_t index) { + [&lookup_input_ptr, this, &gate, &lookup_selector, lookup_inputs_used](std::size_t index) { // Create the converter. auto value_type_to_polynomial_dfs = [](const typename VariableType::assignment_type& coeff) { return polynomial_dfs_type(0, 1, coeff); @@ -534,11 +504,33 @@ namespace nil { math::expression_variable_type_converter converter(value_type_to_polynomial_dfs); const auto& constraint = gate.constraints[index]; - math::polynomial_dfs l = lookup_selector * (typename FieldType::value_type(constraint.table_id)); + polynomial_dfs_type l = lookup_selector * (typename FieldType::value_type(constraint.table_id)); + typename FieldType::value_type theta_acc = this->theta; for(std::size_t k = 0; k < constraint.lookup_input.size(); k++){ math::expression expr = converter.convert(constraint.lookup_input[k]); - math::cached_expression_evaluator evaluator(expr, get_var_value); + + // For each variable with a rotation pre-compute its value. + std::unordered_map rotated_variable_values; + + math::expression_for_each_variable_visitor visitor( + [&rotated_variable_values, &assignments=plonk_columns, &domain=basic_domain] + (const polynomial_dfs_variable_type& var) { + if (var.rotation == 0) + return; + rotated_variable_values[var] = assignments.get_variable_value(var, domain); + }); + visitor.visit(expr); + + math::cached_expression_evaluator evaluator(expr, + [&domain=basic_domain, &assignments=plonk_columns, &rotated_variable_values] + (const polynomial_dfs_variable_type &var) -> const polynomial_dfs_type& { + if (var.rotation == 0) { + return assignments.get_variable_value_without_rotation(var); + } + return rotated_variable_values[var]; + } + ); l += theta_acc * lookup_selector * evaluator.evaluate(); theta_acc *= this->theta; diff --git a/libs/parallel-zk/test/math/expression.cpp b/libs/parallel-zk/test/math/expression.cpp index 86f539e9..6d337a68 100644 --- a/libs/parallel-zk/test/math/expression.cpp +++ b/libs/parallel-zk/test/math/expression.cpp @@ -46,110 +46,116 @@ using namespace nil::crypto3::math; BOOST_AUTO_TEST_SUITE(expression_tests_suite) - BOOST_AUTO_TEST_CASE(expression_to_non_linear_combination_test) { - - // setup - using curve_type = algebra::curves::pallas; - using FieldType = typename curve_type::base_field_type; - using variable_type = typename nil::crypto3::zk::snark::plonk_variable; - - variable_type w0(0, 0, variable_type::column_type::witness); - variable_type w1(3, -1, variable_type::column_type::public_input); - variable_type w2(4, 1, variable_type::column_type::public_input); - variable_type w3(6, 2, variable_type::column_type::constant); - - expression expr = (w0 + w1) * (w2 + w3) - w1 * (w2 + w0); - - expression_to_non_linear_combination_visitor visitor; - non_linear_combination result = visitor.convert(expr); - non_linear_combination expected({w0 * w2, w0 * w3, w1 * w3, -w1 * w0}); - - // We may get the terms in a different order due to changes in the code, and that's fine. - BOOST_CHECK_EQUAL(result, expected); - } - - BOOST_AUTO_TEST_CASE(expression_evaluation_test) { - - // setup - using curve_type = algebra::curves::pallas; - using FieldType = typename curve_type::base_field_type; - using variable_type = typename nil::crypto3::zk::snark::plonk_variable; - - variable_type w0(0, 0, variable_type::column_type::witness); - variable_type w1(3, -1, variable_type::column_type::public_input); - variable_type w2(4, 1, variable_type::column_type::public_input); - variable_type w3(6, 2, variable_type::column_type::constant); - - expression expr = (w0 + w1) * (w2 + w3); - - expression_evaluator evaluator( - expr, - [&w0, &w1, &w2, &w3](const variable_type &var) { - if (var == w0) return variable_type::assignment_type(1u); - if (var == w1) return variable_type::assignment_type(2u); - if (var == w2) return variable_type::assignment_type(3u); - if (var == w3) return variable_type::assignment_type(4u); - return variable_type::assignment_type::zero(); - } - ); - - BOOST_CHECK(evaluator.evaluate() == variable_type::assignment_type((1u + 2u) * (3u + 4u))); - } - - BOOST_AUTO_TEST_CASE(expression_max_degree_visitor_test) { - - // setup - using curve_type = algebra::curves::pallas; - using FieldType = typename curve_type::base_field_type; - using variable_type = typename nil::crypto3::zk::snark::plonk_variable; - - variable_type w0(0, 0, variable_type::column_type::witness); - variable_type w1(3, -1, variable_type::column_type::public_input); - variable_type w2(4, 1, variable_type::column_type::public_input); - variable_type w3(6, 2, variable_type::column_type::constant); - - expression expr = (w0 + w1) * (w2 + w3) + w0 * w1 * (w2 + w3); - - expression_max_degree_visitor visitor; - - BOOST_CHECK_EQUAL(visitor.compute_max_degree(expr), 3); - } - - BOOST_AUTO_TEST_CASE(expression_for_each_variable_visitor_test) { - - // setup - using curve_type = algebra::curves::pallas; - using FieldType = typename curve_type::base_field_type; - using variable_type = typename nil::crypto3::zk::snark::plonk_variable; - - variable_type w0(0, 0, variable_type::column_type::witness); - variable_type w1(3, -1, variable_type::column_type::public_input); - variable_type w2(4, 1, variable_type::column_type::public_input); - variable_type w3(6, 2, variable_type::column_type::constant); - - expression expr = (w0 + w1) * (w2 + w3) + w0 * w1 * (w2 + w3); - - std::set variable_indices; - std::set variable_rotations; - - expression_for_each_variable_visitor visitor( - [&variable_indices, &variable_rotations](const variable_type &var) { - variable_indices.insert(var.index); - variable_rotations.insert(var.rotation); - } - ); - - visitor.visit(expr); - - std::set expected_indices = {0, 3, 4, 6}; - std::set expected_rotations = {0, -1, 1, 2}; - - BOOST_CHECK_EQUAL_COLLECTIONS( - variable_indices.begin(), variable_indices.end(), - expected_indices.begin(), expected_indices.end()); - BOOST_CHECK_EQUAL_COLLECTIONS( - variable_rotations.begin(), variable_rotations.end(), - expected_rotations.begin(), expected_rotations.end()); - } +BOOST_AUTO_TEST_CASE(expression_to_non_linear_combination_test) { + + // setup + using curve_type = algebra::curves::pallas; + using FieldType = typename curve_type::base_field_type; + using variable_type = typename nil::crypto3::zk::snark::plonk_variable; + + variable_type w0(0, 0, variable_type::column_type::witness); + variable_type w1(3, -1, variable_type::column_type::public_input); + variable_type w2(4, 1, variable_type::column_type::public_input); + variable_type w3(6, 2, variable_type::column_type::constant); + + expression expr = (w0 + w1) * (w2 + w3) - w1 * (w2 + w0); + + expression_to_non_linear_combination_visitor visitor; + non_linear_combination result = visitor.convert(expr); + non_linear_combination expected({w0 * w2, w0 * w3, w1 * w3, -w1 * w0}); + + // We may get the terms in a different order due to changes in the code, and that's fine. + BOOST_CHECK_EQUAL(result, expected); +} + +BOOST_AUTO_TEST_CASE(expression_evaluation_test) { + + // setup + using curve_type = algebra::curves::pallas; + using FieldType = typename curve_type::base_field_type; + using variable_type = typename nil::crypto3::zk::snark::plonk_variable; + + variable_type w0(0, 0, variable_type::column_type::witness); + variable_type w1(3, -1, variable_type::column_type::public_input); + variable_type w2(4, 1, variable_type::column_type::public_input); + variable_type w3(6, 2, variable_type::column_type::constant); + + expression expr = (w0 + w1) * (w2 + w3); + + variable_type::assignment_type w0_value(1u); + variable_type::assignment_type w1_value(2u); + variable_type::assignment_type w2_value(3u); + variable_type::assignment_type w3_value(4u); + expression_evaluator evaluator( + expr, + [&w0, &w1, &w2, &w3, &w0_value, &w1_value, &w2_value, &w3_value] + (const variable_type& var) -> const variable_type::assignment_type& { + if (var == w0) return w0_value; + if (var == w1) return w1_value; + if (var == w2) return w2_value; + if (var == w3) return w3_value; + std::cerr << "Variable not found" << std::endl; + abort(); + } + ); + + BOOST_CHECK(evaluator.evaluate() == variable_type::assignment_type((1u + 2u) * (3u + 4u))); +} + +BOOST_AUTO_TEST_CASE(expression_max_degree_visitor_test) { + + // setup + using curve_type = algebra::curves::pallas; + using FieldType = typename curve_type::base_field_type; + using variable_type = typename nil::crypto3::zk::snark::plonk_variable; + + variable_type w0(0, 0, variable_type::column_type::witness); + variable_type w1(3, -1, variable_type::column_type::public_input); + variable_type w2(4, 1, variable_type::column_type::public_input); + variable_type w3(6, 2, variable_type::column_type::constant); + + expression expr = (w0 + w1) * (w2 + w3) + w0 * w1 * (w2 + w3); + + expression_max_degree_visitor visitor; + + BOOST_CHECK_EQUAL(visitor.compute_max_degree(expr), 3); +} + +BOOST_AUTO_TEST_CASE(expression_for_each_variable_visitor_test) { + + // setup + using curve_type = algebra::curves::pallas; + using FieldType = typename curve_type::base_field_type; + using variable_type = typename nil::crypto3::zk::snark::plonk_variable; + + variable_type w0(0, 0, variable_type::column_type::witness); + variable_type w1(3, -1, variable_type::column_type::public_input); + variable_type w2(4, 1, variable_type::column_type::public_input); + variable_type w3(6, 2, variable_type::column_type::constant); + + expression expr = (w0 + w1) * (w2 + w3) + w0 * w1 * (w2 + w3); + + std::set variable_indices; + std::set variable_rotations; + + expression_for_each_variable_visitor visitor( + [&variable_indices, &variable_rotations](const variable_type& var) { + variable_indices.insert(var.index); + variable_rotations.insert(var.rotation); + } + ); + + visitor.visit(expr); + + std::set expected_indices = {0, 3, 4, 6}; + std::set expected_rotations = {0, -1, 1, 2}; + + BOOST_CHECK_EQUAL_COLLECTIONS( + variable_indices.begin(), variable_indices.end(), + expected_indices.begin(), expected_indices.end()); + BOOST_CHECK_EQUAL_COLLECTIONS( + variable_rotations.begin(), variable_rotations.end(), + expected_rotations.begin(), expected_rotations.end()); +} BOOST_AUTO_TEST_SUITE_END()