From 77aad9df700e18ecdb94b6a9d90d571da4eb114c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 5 Feb 2025 13:19:22 -0800 Subject: [PATCH] Draft: add support for cloning cel::Expr. PiperOrigin-RevId: 723626957 --- common/BUILD | 1 + common/expr.cc | 126 +++++++++++++++++++++ common/expr.h | 3 + extensions/protobuf/BUILD | 1 + extensions/protobuf/ast_converters_test.cc | 18 +++ 5 files changed, 149 insertions(+) diff --git a/common/BUILD b/common/BUILD index 551f39cd6..e2bce9288 100644 --- a/common/BUILD +++ b/common/BUILD @@ -33,6 +33,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", diff --git a/common/expr.cc b/common/expr.cc index 60fb97050..52a220e85 100644 --- a/common/expr.cc +++ b/common/expr.cc @@ -14,10 +14,124 @@ #include "common/expr.h" +#include + #include "absl/base/no_destructor.h" +#include "absl/functional/overload.h" +#include "absl/types/variant.h" +#include "common/constant.h" namespace cel { +namespace { + +struct CopyStackRecord { + const Expr* src; + Expr* dst; +}; + +void CopyNode(CopyStackRecord element, std::vector& stack) { + const Expr* src = element.src; + Expr* dst = element.dst; + dst->set_id(src->id()); + absl::visit( + absl::Overload( + [](const UnspecifiedExpr&) {}, + [=](const IdentExpr& i) { + dst->mutable_ident_expr().set_name(i.name()); + }, + [=](const Constant& c) { dst->mutable_const_expr() = c; }, + [&](const SelectExpr& s) { + dst->mutable_select_expr().set_field(s.field()); + dst->mutable_select_expr().set_test_only(s.test_only()); + + if (s.has_operand()) { + stack.push_back({&s.operand(), + &dst->mutable_select_expr().mutable_operand()}); + } + }, + [&](const CallExpr& c) { + dst->mutable_call_expr().set_function(c.function()); + if (c.has_target()) { + stack.push_back( + {&c.target(), &dst->mutable_call_expr().mutable_target()}); + } + dst->mutable_call_expr().mutable_args().resize(c.args().size()); + for (int i = 0; i < dst->mutable_call_expr().mutable_args().size(); + ++i) { + stack.push_back( + {&c.args()[i], &dst->mutable_call_expr().mutable_args()[i]}); + } + }, + [&](const ListExpr& c) { + auto& dst_list = dst->mutable_list_expr(); + dst_list.mutable_elements().resize(c.elements().size()); + for (int i = 0; i < src->list_expr().elements().size(); ++i) { + dst_list.mutable_elements()[i].set_optional( + c.elements()[i].optional()); + stack.push_back({&c.elements()[i].expr(), + &dst_list.mutable_elements()[i].mutable_expr()}); + } + }, + [&](const StructExpr& s) { + auto& dst_struct = dst->mutable_struct_expr(); + dst_struct.mutable_fields().resize(s.fields().size()); + dst_struct.set_name(s.name()); + for (int i = 0; i < s.fields().size(); ++i) { + dst_struct.mutable_fields()[i].set_optional( + s.fields()[i].optional()); + dst_struct.mutable_fields()[i].set_name(s.fields()[i].name()); + dst_struct.mutable_fields()[i].set_id(s.fields()[i].id()); + stack.push_back( + {&s.fields()[i].value(), + &dst_struct.mutable_fields()[i].mutable_value()}); + } + }, + [&](const MapExpr& c) { + auto& dst_map = dst->mutable_map_expr(); + dst_map.mutable_entries().resize(c.entries().size()); + for (int i = 0; i < c.entries().size(); ++i) { + dst_map.mutable_entries()[i].set_optional( + c.entries()[i].optional()); + dst_map.mutable_entries()[i].set_id(c.entries()[i].id()); + stack.push_back({&c.entries()[i].key(), + &dst_map.mutable_entries()[i].mutable_key()}); + stack.push_back({&c.entries()[i].value(), + &dst_map.mutable_entries()[i].mutable_value()}); + } + }, + [&](const ComprehensionExpr& c) { + auto& dst_comprehension = dst->mutable_comprehension_expr(); + dst_comprehension.set_iter_var(c.iter_var()); + dst_comprehension.set_iter_var2(c.iter_var2()); + dst_comprehension.set_accu_var(c.accu_var()); + if (c.has_accu_init()) { + stack.push_back( + {&c.accu_init(), &dst_comprehension.mutable_accu_init()}); + } + if (c.has_iter_range()) { + stack.push_back( + {&c.iter_range(), &dst_comprehension.mutable_iter_range()}); + } + if (c.has_loop_condition()) { + stack.push_back({&c.loop_condition(), + &dst_comprehension.mutable_loop_condition()}); + } + if (c.has_loop_step()) { + stack.push_back( + {&c.loop_step(), &dst_comprehension.mutable_loop_step()}); + } + if (c.has_result()) { + stack.push_back( + {&c.result(), &dst_comprehension.mutable_result()}); + } + } + + ), + src->kind()); +} +} // namespace + const UnspecifiedExpr& UnspecifiedExpr::default_instance() { static const absl::NoDestructor instance; return *instance; @@ -63,4 +177,16 @@ const Expr& Expr::default_instance() { return *instance; } +Expr CloneExpr(const Expr& expr) { + Expr result; + std::vector stack; + stack.push_back({&expr, &result}); + while (!stack.empty()) { + CopyStackRecord element = stack.back(); + stack.pop_back(); + CopyNode(element, stack); + } + return result; +} + } // namespace cel diff --git a/common/expr.h b/common/expr.h index 18828471f..e0bb07802 100644 --- a/common/expr.h +++ b/common/expr.h @@ -48,6 +48,9 @@ class ComprehensionExpr; inline constexpr absl::string_view kAccumulatorVariableName = "__result__"; +// Returns a deep copy of the given expression node. +Expr CloneExpr(const Expr& expr); + bool operator==(const Expr& lhs, const Expr& rhs); inline bool operator!=(const Expr& lhs, const Expr& rhs) { diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 7883384eb..14201648e 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -76,6 +76,7 @@ cc_test( ":ast_converters", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", + "//common:expr", "//internal:proto_matchers", "//internal:testing", "//parser", diff --git a/extensions/protobuf/ast_converters_test.cc b/extensions/protobuf/ast_converters_test.cc index 3cf01295f..8c2846f37 100644 --- a/extensions/protobuf/ast_converters_test.cc +++ b/extensions/protobuf/ast_converters_test.cc @@ -31,6 +31,7 @@ #include "absl/types/variant.h" #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" +#include "common/expr.h" #include "internal/proto_matchers.h" #include "internal/testing.h" #include "parser/options.h" @@ -801,6 +802,23 @@ TEST_P(ConversionRoundTripTest, ParsedExprCopyable) { IsOkAndHolds(EqualsProto(parsed_expr))); } +TEST_P(ConversionRoundTripTest, ExprClonable) { + ASSERT_OK_AND_ASSIGN(ParsedExprPb parsed_expr, + Parse(GetParam().expr, "", options_)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, + CreateAstFromParsedExpr(parsed_expr)); + + auto& impl = ast_internal::AstImpl::CastFromPublicAst(*ast); + impl.root_expr() = CloneExpr(impl.root_expr()); + + EXPECT_THAT(CreateCheckedExprFromAst(impl), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("AST is not type-checked"))); + EXPECT_THAT(CreateParsedExprFromAst(impl), + IsOkAndHolds(EqualsProto(parsed_expr))); +} + TEST_P(ConversionRoundTripTest, CheckedExprCopyable) { ASSERT_OK_AND_ASSIGN(ParsedExprPb parsed_expr, Parse(GetParam().expr, "", options_));