diff --git a/testing/testrunner/BUILD b/testing/testrunner/BUILD index 975b5884d..4bc8170c3 100644 --- a/testing/testrunner/BUILD +++ b/testing/testrunner/BUILD @@ -13,6 +13,7 @@ cc_library( hdrs = ["cel_test_context.h"], deps = [ ":cel_expression_source", + ":result_matcher", "//compiler", "//eval/public:cel_expression", "//runtime", @@ -31,6 +32,7 @@ cc_library( deps = [ ":cel_expression_source", ":cel_test_context", + ":result_matcher", "//checker:validation_result", "//common:ast", "//common:ast_proto", @@ -51,7 +53,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", - "@com_google_protobuf//:differencer", "@com_google_protobuf//:protobuf", ], ) @@ -80,6 +81,8 @@ cc_test( deps = [ ":cel_expression_source", ":cel_test_context", + ":default_result_matcher", + ":result_matcher", ":runner_lib", "//checker:type_checker_builder", "//checker:validation_result", @@ -135,3 +138,34 @@ cc_library( hdrs = ["cel_expression_source.h"], deps = ["@com_google_cel_spec//proto/cel/expr:checked_cc_proto"], ) + +cc_library( + name = "result_matcher", + hdrs = ["result_matcher.h"], + deps = [ + "//common:value", + "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "default_result_matcher", + srcs = ["default_result_matcher.cc"], + deps = [ + ":cel_test_context", + ":result_matcher", + "//common:value", + "//common/internal:value_conversion", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:differencer", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/testing/testrunner/cel_test_context.h b/testing/testrunner/cel_test_context.h index 335f25aa4..00bd654a6 100644 --- a/testing/testrunner/cel_test_context.h +++ b/testing/testrunner/cel_test_context.h @@ -29,8 +29,17 @@ #include "eval/public/cel_expression.h" #include "runtime/runtime.h" #include "testing/testrunner/cel_expression_source.h" +#include "testing/testrunner/result_matcher.h" + namespace cel::test { +// Factory function for creating a `ResultMatcher` with default settings. +// +// This is used by `CelTestContext` when a custom result matcher is not +// provided in `CelTestContextOptions`. It ensures that a default matcher is +// always available for performing assertions in tests. +std::unique_ptr CreateDefaultResultMatcher(); + // Struct to hold optional parameters for `CelTestContext`. struct CelTestContextOptions { // The source for the CEL expression to be evaluated in the test. @@ -50,6 +59,10 @@ struct CelTestContextOptions { // This logic is handled by the test runner when it constructs the final // activation. absl::flat_hash_map custom_bindings; + + // An optional result matcher to be used for assertions. If not provided, a + // default result matcher will be used. + std::unique_ptr result_matcher = nullptr; }; // The context class for a CEL test, holding configurations needed to evaluate @@ -115,6 +128,9 @@ class CelTestContext { return cel_test_context_options_.custom_bindings; } + // Returns the result matcher to be used for assertions. + const ResultMatcher& result_matcher() const { return *result_matcher_; } + private: // Delete copy and move constructors. CelTestContext(const CelTestContext&) = delete; @@ -128,12 +144,24 @@ class CelTestContext { cel_expression_builder, CelTestContextOptions options) : cel_test_context_options_(std::move(options)), - cel_expression_builder_(std::move(cel_expression_builder)) {} + cel_expression_builder_(std::move(cel_expression_builder)) { + if (cel_test_context_options_.result_matcher) { + result_matcher_ = std::move(cel_test_context_options_.result_matcher); + } else { + result_matcher_ = CreateDefaultResultMatcher(); + } + } CelTestContext(std::unique_ptr runtime, CelTestContextOptions options) : cel_test_context_options_(std::move(options)), - runtime_(std::move(runtime)) {} + runtime_(std::move(runtime)) { + if (cel_test_context_options_.result_matcher) { + result_matcher_ = std::move(cel_test_context_options_.result_matcher); + } else { + result_matcher_ = CreateDefaultResultMatcher(); + } + } // Configuration for the expression to be executed. CelTestContextOptions cel_test_context_options_; @@ -148,6 +176,9 @@ class CelTestContext { // needed to generate Program. Users should either provide a runtime, or the // CelExpressionBuilder. std::unique_ptr runtime_; + + // The result matcher to be used for assertions. + std::unique_ptr result_matcher_; }; } // namespace cel::test diff --git a/testing/testrunner/default_result_matcher.cc b/testing/testrunner/default_result_matcher.cc new file mode 100644 index 000000000..255bec44e --- /dev/null +++ b/testing/testrunner/default_result_matcher.cc @@ -0,0 +1,122 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "cel/expr/eval.pb.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "common/internal/value_conversion.h" +#include "common/value.h" +#include "internal/testing.h" +#include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/result_matcher.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/field_comparator.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel::test { + +namespace { + +using ValueProto = ::cel::expr::Value; +using ::cel::expr::conformance::test::TestOutput; + +bool IsEqual(const ValueProto& expected, const ValueProto& actual) { + static auto* kFieldComparator = []() { + auto* field_comparator = new google::protobuf::util::DefaultFieldComparator(); + field_comparator->set_treat_nan_as_equal(true); + return field_comparator; + }(); + static auto* kDifferencer = []() { + auto* differencer = new google::protobuf::util::MessageDifferencer(); + differencer->set_message_field_comparison( + google::protobuf::util::MessageDifferencer::EQUIVALENT); + differencer->set_field_comparator(kFieldComparator); + const auto* descriptor = cel::expr::MapValue::descriptor(); + const auto* entries_field = descriptor->FindFieldByName("entries"); + const auto* key_field = + entries_field->message_type()->FindFieldByName("key"); + differencer->TreatAsMap(entries_field, key_field); + return differencer; + }(); + return kDifferencer->Compare(expected, actual); +} + +MATCHER_P(MatchesValue, expected, "") { return IsEqual(arg, expected); } + +class DefaultResultMatcher : public cel::test::ResultMatcher { + public: + void Match(const ResultMatcherParams& params) const override { + const TestOutput& output = params.expected_output; + const auto& computed_output = params.computed_output; + google::protobuf::Arena* arena = params.arena; + + if (output.has_result_value()) { + AssertValue(computed_output, output, params.test_context, arena); + } else if (output.has_eval_error()) { + AssertError(computed_output, output); + } else if (output.has_unknown()) { + ADD_FAILURE() << "Unknown assertions not implemented yet."; + } else { + ADD_FAILURE() << "Unexpected output kind."; + } + } + + private: + void AssertValue(const cel::Value& computed, const TestOutput& output, + const CelTestContext& test_context, + google::protobuf::Arena* arena) const { + ValueProto expected_value_proto; + const auto* descriptor_pool = + test_context.runtime() != nullptr + ? test_context.runtime()->GetDescriptorPool() + : google::protobuf::DescriptorPool::generated_pool(); + auto* message_factory = test_context.runtime() != nullptr + ? test_context.runtime()->GetMessageFactory() + : google::protobuf::MessageFactory::generated_factory(); + + ValueProto computed_expr_value; + ASSERT_OK_AND_ASSIGN( + computed_expr_value, + ToExprValue(computed, descriptor_pool, message_factory, arena)); + EXPECT_THAT(output.result_value(), MatchesValue(computed_expr_value)); + } + + void AssertError(const cel::Value& computed, const TestOutput& output) const { + if (!computed.IsError()) { + ADD_FAILURE() << "Expected error but got value: " + << computed.DebugString(); + return; + } + absl::Status computed_status = computed.AsError()->ToStatus(); + // We selected the first error in the set for comparison because there is + // only one runtime error that is reported even if there are multiple errors + // in the critical path. + ASSERT_TRUE(output.eval_error().errors_size() == 1) + << "Expected exactly one error but got: " + << output.eval_error().errors_size(); + ASSERT_EQ(computed_status.message(), + output.eval_error().errors(0).message()); + } +}; +} // namespace + +std::unique_ptr CreateDefaultResultMatcher() { + return std::make_unique(); +} +} // namespace cel::test diff --git a/testing/testrunner/result_matcher.h b/testing/testrunner/result_matcher.h new file mode 100644 index 000000000..ad4a7d422 --- /dev/null +++ b/testing/testrunner/result_matcher.h @@ -0,0 +1,44 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RESULT_MATCHER_H_ +#define THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RESULT_MATCHER_H_ + +#include "common/value.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" + +namespace cel::test { + +// Forward declare CelTestContext to avoid circular includes. +class CelTestContext; + +// Parameters passed to the ResultMatcher for performing assertions. +struct ResultMatcherParams { + const cel::expr::conformance::test::TestOutput& expected_output; + const CelTestContext& test_context; + const cel::Value& computed_output; + google::protobuf::Arena* arena; +}; + +// Interface for a custom result matcher. +class ResultMatcher { + public: + virtual ~ResultMatcher() = default; + virtual void Match(const ResultMatcherParams& params) const = 0; +}; + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTING_TESTRUNNER_RESULT_MATCHER_H_ diff --git a/testing/testrunner/runner_lib.cc b/testing/testrunner/runner_lib.cc index 2b0375f91..f9cb64951 100644 --- a/testing/testrunner/runner_lib.cc +++ b/testing/testrunner/runner_lib.cc @@ -41,12 +41,11 @@ #include "runtime/runtime.h" #include "testing/testrunner/cel_expression_source.h" #include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/result_matcher.h" #include "cel/expr/conformance/test/suite.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" -#include "google/protobuf/util/field_comparator.h" -#include "google/protobuf/util/message_differencer.h" namespace cel::test { namespace { @@ -262,83 +261,33 @@ absl::StatusOr CreateLegacyActivationFromBindings( return activation; } -bool IsEqual(const ValueProto& expected, const ValueProto& actual) { - static auto* kFieldComparator = []() { - auto* field_comparator = new google::protobuf::util::DefaultFieldComparator(); - field_comparator->set_treat_nan_as_equal(true); - return field_comparator; - }(); - static auto* kDifferencer = []() { - auto* differencer = new google::protobuf::util::MessageDifferencer(); - differencer->set_message_field_comparison( - google::protobuf::util::MessageDifferencer::EQUIVALENT); - differencer->set_field_comparator(kFieldComparator); - const auto* descriptor = cel::expr::MapValue::descriptor(); - const auto* entries_field = descriptor->FindFieldByName("entries"); - const auto* key_field = - entries_field->message_type()->FindFieldByName("key"); - differencer->TreatAsMap(entries_field, key_field); - return differencer; - }(); - return kDifferencer->Compare(expected, actual); -} +absl::StatusOr ResolveResultExpr(const TestOutput& test_output, + const CelTestContext& test_context, + google::protobuf::Arena* arena) { + TestOutput updated_output = test_output; + if (!updated_output.has_result_expr()) { + return updated_output; + } -MATCHER_P(MatchesValue, expected, "") { return IsEqual(arg, expected); } -} // namespace + InputValue input_value; + input_value.set_expr(updated_output.result_expr()); -void TestRunner::AssertValue(const cel::Value& computed, - const TestOutput& output, google::protobuf::Arena* arena) { - ValueProto expected_value_proto; - const auto* descriptor_pool = GetDescriptorPool(*test_context_); - auto* message_factory = GetMessageFactory(*test_context_); - if (output.has_result_value()) { - expected_value_proto = output.result_value(); - } else if (output.has_result_expr()) { - InputValue input_value; - input_value.set_expr(output.result_expr()); - ASSERT_OK_AND_ASSIGN(cel::Value resolved_cel_value, - ResolveInputValue(input_value, *test_context_, arena)); - ASSERT_OK_AND_ASSIGN(expected_value_proto, - ToExprValue(resolved_cel_value, descriptor_pool, - message_factory, arena)); - } - ValueProto computed_expr_value; - ASSERT_OK_AND_ASSIGN( - computed_expr_value, - ToExprValue(computed, descriptor_pool, message_factory, arena)); - EXPECT_THAT(expected_value_proto, MatchesValue(computed_expr_value)); -} + CEL_ASSIGN_OR_RETURN(cel::Value resolved_value, + ResolveInputValue(input_value, test_context, arena)); -void TestRunner::AssertError(const cel::Value& computed, - const TestOutput& output) { - if (!computed.IsError()) { - ADD_FAILURE() << "Expected error but got value: " << computed.DebugString(); - return; - } - absl::Status computed_status = computed.AsError()->ToStatus(); - // We selected the first error in the set for comparison because there is only - // one runtime error that is reported even if there are multiple errors in the - // critical path. - ASSERT_TRUE(output.eval_error().errors_size() == 1) - << "Expected exactly one error but got: " - << output.eval_error().errors_size(); - ASSERT_EQ(computed_status.message(), output.eval_error().errors(0).message()); -} + const auto* descriptor_pool = GetDescriptorPool(test_context); + auto* message_factory = GetMessageFactory(test_context); -void TestRunner::Assert(const cel::Value& computed, const TestCase& test_case, - google::protobuf::Arena* arena) { - TestOutput output = test_case.output(); - if (output.has_result_value() || output.has_result_expr()) { - AssertValue(computed, output, arena); - } else if (output.has_eval_error()) { - AssertError(computed, output); - } else if (output.has_unknown()) { - ADD_FAILURE() << "Unknown assertions not implemented yet."; - } else { - ADD_FAILURE() << "Unexpected output kind."; - } + updated_output.clear_result_expr(); + CEL_ASSIGN_OR_RETURN( + *updated_output.mutable_result_value(), + ToExprValue(resolved_value, descriptor_pool, message_factory, arena)); + + return updated_output; } +} // namespace + absl::StatusOr TestRunner::EvalWithRuntime( const CheckedExpr& checked_expr, const TestCase& test_case, google::protobuf::Arena* arena) { @@ -385,16 +334,21 @@ void TestRunner::RunTest(const TestCase& test_case) { // EvalWithRuntime or EvalWithCelExpressionBuilder might contain pointers to // the arena. The arena has to be alive during the assertion. google::protobuf::Arena arena; + cel::Value computed_output; + ASSERT_OK_AND_ASSIGN( + TestOutput resolved_output, + ResolveResultExpr(test_case.output(), *test_context_, &arena)); ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr, GetCheckedExpr()); if (test_context_->runtime() != nullptr) { - ASSERT_OK_AND_ASSIGN(cel::Value result, + ASSERT_OK_AND_ASSIGN(computed_output, EvalWithRuntime(checked_expr, test_case, &arena)); - ASSERT_NO_FATAL_FAILURE(Assert(result, test_case, &arena)); } else if (test_context_->cel_expression_builder() != nullptr) { - ASSERT_OK_AND_ASSIGN( - cel::Value result, - EvalWithCelExpressionBuilder(checked_expr, test_case, &arena)); - ASSERT_NO_FATAL_FAILURE(Assert(result, test_case, &arena)); + ASSERT_OK_AND_ASSIGN(computed_output, EvalWithCelExpressionBuilder( + checked_expr, test_case, &arena)); } + + ResultMatcherParams params{resolved_output, *test_context_, computed_output, + &arena}; + test_context_->result_matcher().Match(params); } } // namespace cel::test diff --git a/testing/testrunner/runner_lib_test.cc b/testing/testrunner/runner_lib_test.cc index f63952b2c..886f6bf03 100644 --- a/testing/testrunner/runner_lib_test.cc +++ b/testing/testrunner/runner_lib_test.cc @@ -43,6 +43,7 @@ #include "runtime/standard_runtime_builder_factory.h" #include "testing/testrunner/cel_expression_source.h" #include "testing/testrunner/cel_test_context.h" +#include "testing/testrunner/result_matcher.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -129,6 +130,13 @@ class TestRunnerParamTest : public ::testing::TestWithParam { } }; +class CustomFailMatcher : public cel::test::ResultMatcher { + public: + void Match(const ResultMatcherParams& params) const override { + ADD_FAILURE() << "This test failed because the CUSTOM MATCHER ran!"; + } +}; + TEST_P(TestRunnerParamTest, BasicTestReportsSuccess) { ASSERT_OK_AND_ASSIGN( cel::ValidationResult validation_result, @@ -418,6 +426,31 @@ TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsReportsFailure) { "int64_value: 15"); // expected 15 got 999. } +TEST_P(TestRunnerParamTest, CustomMatcherIsUsedWhenProvided) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("true")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + output { result_value { bool_value: false } } + )pb"); + + // Create the options and pass an instance of our custom matcher. + CelTestContextOptions options; + options.expression_source = + CelExpressionSource::FromCheckedExpr(std::move(checked_expr)); + options.result_matcher = std::make_unique(); + + // Create the context and runner using our options. + ASSERT_OK_AND_ASSIGN(auto context, CreateTestContext(std::move(options))); + TestRunner test_runner(std::move(context)); + + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "This test failed because the CUSTOM MATCHER ran!"); +} + INSTANTIATE_TEST_SUITE_P(TestRunnerTests, TestRunnerParamTest, ::testing::Values(RuntimeApi::kRuntime, RuntimeApi::kBuilder)); diff --git a/testing/testrunner/user_tests/BUILD b/testing/testrunner/user_tests/BUILD index 436176f1c..6f571234f 100644 --- a/testing/testrunner/user_tests/BUILD +++ b/testing/testrunner/user_tests/BUILD @@ -24,6 +24,7 @@ cc_library( "//testing/testrunner:cel_expression_source", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_test_factories", + "//testing/testrunner:default_result_matcher", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status:statusor", @@ -53,6 +54,7 @@ cc_library( "//testing/testrunner:cel_expression_source", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_test_factories", + "//testing/testrunner:default_result_matcher", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view",