|
18 | 18 | #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ |
19 | 19 | #define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ |
20 | 20 |
|
| 21 | +#include <cstddef> |
21 | 22 | #include <functional> |
22 | 23 | #include <memory> |
| 24 | +#include <tuple> |
| 25 | +#include <utility> |
23 | 26 | #include <vector> |
24 | 27 |
|
25 | 28 | #include "absl/base/nullability.h" |
26 | 29 | #include "absl/functional/any_invocable.h" |
27 | | -#include "absl/functional/bind_front.h" |
28 | 30 | #include "absl/status/status.h" |
29 | 31 | #include "absl/status/statusor.h" |
30 | 32 | #include "absl/strings/str_cat.h" |
31 | 33 | #include "absl/strings/string_view.h" |
32 | 34 | #include "absl/types/span.h" |
33 | 35 | #include "common/function_descriptor.h" |
34 | | -#include "common/kind.h" |
35 | 36 | #include "common/value.h" |
36 | 37 | #include "internal/status_macros.h" |
37 | 38 | #include "runtime/function.h" |
@@ -94,79 +95,73 @@ struct AdaptedTypeTraits<const T&> { |
94 | 95 | static T ToArg(AssignableType v) { return v; } |
95 | 96 | }; |
96 | 97 |
|
97 | | -template <typename... Args> |
98 | | -struct KindAdderImpl; |
99 | | - |
100 | | -template <typename Arg, typename... Args> |
101 | | -struct KindAdderImpl<Arg, Args...> { |
102 | | - static void AddTo(std::vector<cel::Kind>& args) { |
103 | | - args.push_back(AdaptedKind<Arg>()); |
104 | | - KindAdderImpl<Args...>::AddTo(args); |
| 98 | +template <size_t I, typename... Args> |
| 99 | +struct AdaptHelperImpl { |
| 100 | + template <typename T> |
| 101 | + static absl::Status Apply(absl::Span<const Value> input, T& output) { |
| 102 | + static_assert(sizeof...(Args) > 0); |
| 103 | + static_assert(std::tuple_size_v<T> == sizeof...(Args)); |
| 104 | + CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[I]}(&std::get<I>(output))); |
| 105 | + if constexpr (I == sizeof...(Args) - 1) { |
| 106 | + return absl::OkStatus(); |
| 107 | + } else { |
| 108 | + CEL_RETURN_IF_ERROR( |
| 109 | + (AdaptHelperImpl<I + 1, Args...>::template Apply<T>(input, output))); |
| 110 | + } |
| 111 | + return absl::OkStatus(); |
105 | 112 | } |
106 | 113 | }; |
107 | 114 |
|
108 | | -template <> |
109 | | -struct KindAdderImpl<> { |
110 | | - static void AddTo(std::vector<cel::Kind>& args) {} |
111 | | -}; |
112 | | - |
113 | 115 | template <typename... Args> |
114 | | -struct KindAdder { |
115 | | - static std::vector<cel::Kind> Kinds() { |
116 | | - std::vector<cel::Kind> args; |
117 | | - KindAdderImpl<Args...>::AddTo(args); |
118 | | - return args; |
| 116 | +struct AdaptHelper { |
| 117 | + template <typename T> |
| 118 | + static absl::Status Apply(absl::Span<const Value> input, T& output) { |
| 119 | + return AdaptHelperImpl<0, Args...>::template Apply<T>(input, output); |
119 | 120 | } |
120 | 121 | }; |
121 | 122 |
|
122 | | -template <typename T> |
123 | | -struct ApplyReturnType { |
124 | | - using type = absl::StatusOr<T>; |
125 | | -}; |
126 | | - |
127 | | -template <typename T> |
128 | | -struct ApplyReturnType<absl::StatusOr<T>> { |
129 | | - using type = absl::StatusOr<T>; |
130 | | -}; |
131 | | - |
132 | | -template <int N, typename Arg, typename... Args> |
133 | | -struct IndexerImpl { |
134 | | - using type = typename IndexerImpl<N - 1, Args...>::type; |
135 | | -}; |
136 | | - |
137 | | -template <typename Arg, typename... Args> |
138 | | -struct IndexerImpl<0, Arg, Args...> { |
139 | | - using type = Arg; |
140 | | -}; |
| 123 | +template <typename... Args> |
| 124 | +struct ToArgsImpl { |
| 125 | + template <int I, typename T> |
| 126 | + struct El { |
| 127 | + using type = T; |
| 128 | + constexpr static size_t index = I; |
| 129 | + }; |
141 | 130 |
|
142 | | -template <int N, typename... Args> |
143 | | -struct Indexer { |
144 | | - static_assert(N < sizeof...(Args) && N >= 0); |
145 | | - using type = typename IndexerImpl<N, Args...>::type; |
146 | | -}; |
| 131 | + template <typename... Es> |
| 132 | + struct ZipHolder { |
| 133 | + template <typename ResultType, typename TupleType, typename Op> |
| 134 | + static ResultType ToArgs( |
| 135 | + Op&& op, const TupleType& argbuffer, |
| 136 | + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, |
| 137 | + google::protobuf::MessageFactory* absl_nonnull message_factory, |
| 138 | + google::protobuf::Arena* absl_nonnull arena) { |
| 139 | + return std::forward<Op>(op)( |
| 140 | + runtime_internal::AdaptedTypeTraits<typename Es::type>::ToArg( |
| 141 | + std::get<Es::index>(argbuffer))..., |
| 142 | + descriptor_pool, message_factory, arena); |
| 143 | + } |
| 144 | + }; |
147 | 145 |
|
148 | | -template <int N, typename... Args> |
149 | | -struct ApplyHelper { |
150 | | - template <typename T, typename Op> |
151 | | - static typename ApplyReturnType<T>::type Apply( |
152 | | - Op&& op, absl::Span<const Value> input) { |
153 | | - constexpr int idx = sizeof...(Args) - N; |
154 | | - using Arg = typename Indexer<idx, Args...>::type; |
155 | | - using ArgTraits = AdaptedTypeTraits<Arg>; |
156 | | - typename ArgTraits::AssignableType arg_i; |
157 | | - CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[idx]}(&arg_i)); |
158 | | - |
159 | | - return ApplyHelper<N - 1, Args...>::template Apply<T>( |
160 | | - absl::bind_front(std::forward<Op>(op), ArgTraits::ToArg(arg_i)), input); |
| 146 | + template <size_t... Is> |
| 147 | + static ZipHolder<El<Is, Args>...> MakeZip(const std::index_sequence<Is...>&) { |
| 148 | + return ZipHolder<El<Is, Args>...>{}; |
161 | 149 | } |
162 | 150 | }; |
163 | 151 |
|
164 | 152 | template <typename... Args> |
165 | | -struct ApplyHelper<0, Args...> { |
166 | | - template <typename T, typename Op> |
167 | | - static typename ApplyReturnType<T>::type Apply( |
168 | | - Op&& op, absl::Span<const Value> input) { |
169 | | - return op(); |
| 153 | +struct ToArgsHelper { |
| 154 | + template <typename ResultType, typename TupleType, typename Op> |
| 155 | + static ResultType Apply( |
| 156 | + Op&& op, const TupleType& argbuffer, |
| 157 | + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, |
| 158 | + google::protobuf::MessageFactory* absl_nonnull message_factory, |
| 159 | + google::protobuf::Arena* absl_nonnull arena) { |
| 160 | + using Impl = ToArgsImpl<Args...>; |
| 161 | + using Zip = decltype(Impl::MakeZip(std::index_sequence_for<Args...>{})); |
| 162 | + return Zip::template ToArgs<ResultType>(std::forward<Op>(op), argbuffer, |
| 163 | + descriptor_pool, message_factory, |
| 164 | + arena); |
170 | 165 | } |
171 | 166 | }; |
172 | 167 |
|
@@ -629,6 +624,98 @@ class QuaternaryFunctionAdapter |
629 | 624 | }; |
630 | 625 | }; |
631 | 626 |
|
| 627 | +// Primary template for n-ary adapter. |
| 628 | +template <typename T, typename... Args> |
| 629 | +class NaryFunctionAdapter; |
| 630 | + |
| 631 | +template <typename T> |
| 632 | +class NaryFunctionAdapter<T> : public NullaryFunctionAdapter<T> {}; |
| 633 | + |
| 634 | +template <typename T, typename U> |
| 635 | +class NaryFunctionAdapter<T, U> : public UnaryFunctionAdapter<T, U> {}; |
| 636 | + |
| 637 | +template <typename T, typename U, typename V> |
| 638 | +class NaryFunctionAdapter<T, U, V> : public BinaryFunctionAdapter<T, U, V> {}; |
| 639 | + |
| 640 | +template <typename T, typename U, typename V, typename W> |
| 641 | +class NaryFunctionAdapter<T, U, V, W> |
| 642 | + : public TernaryFunctionAdapter<T, U, V, W> {}; |
| 643 | + |
| 644 | +template <typename T, typename U, typename V, typename W, typename X> |
| 645 | +class NaryFunctionAdapter<T, U, V, W, X> |
| 646 | + : public QuaternaryFunctionAdapter<T, U, V, W, X> {}; |
| 647 | + |
| 648 | +// N-ary function adapter. |
| 649 | +// |
| 650 | +// Prefer using one of the specific count adapters above for readability and |
| 651 | +// better error messages. |
| 652 | +template <typename T, typename... Args> |
| 653 | +class NaryFunctionAdapter |
| 654 | + : public RegisterHelper<NaryFunctionAdapter<T, Args...>> { |
| 655 | + public: |
| 656 | + using FunctionType = absl::AnyInvocable<T( |
| 657 | + Args..., const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, |
| 658 | + google::protobuf::MessageFactory* absl_nonnull message_factory, |
| 659 | + google::protobuf::Arena* absl_nonnull arena) const>; |
| 660 | + |
| 661 | + static FunctionDescriptor CreateDescriptor(absl::string_view name, |
| 662 | + bool receiver_style, |
| 663 | + bool is_strict = true) { |
| 664 | + return FunctionDescriptor(name, receiver_style, |
| 665 | + {runtime_internal::AdaptedKind<Args>()...}, |
| 666 | + is_strict); |
| 667 | + } |
| 668 | + |
| 669 | + static std::unique_ptr<cel::Function> WrapFunction(FunctionType fn) { |
| 670 | + return std::make_unique<NaryFunctionImpl>(std::move(fn)); |
| 671 | + } |
| 672 | + |
| 673 | + static std::unique_ptr<cel::Function> WrapFunction( |
| 674 | + absl::AnyInvocable<T(Args...) const> function) { |
| 675 | + return WrapFunction( |
| 676 | + [function = std::move(function)]( |
| 677 | + Args... args, const google::protobuf::DescriptorPool* absl_nonnull, |
| 678 | + google::protobuf::MessageFactory* absl_nonnull, |
| 679 | + google::protobuf::Arena* absl_nonnull) -> T { return function(args...); }); |
| 680 | + } |
| 681 | + |
| 682 | + private: |
| 683 | + class NaryFunctionImpl : public cel::Function { |
| 684 | + private: |
| 685 | + using ArgBuffer = std::tuple< |
| 686 | + typename runtime_internal::AdaptedTypeTraits<Args>::AssignableType...>; |
| 687 | + |
| 688 | + public: |
| 689 | + explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} |
| 690 | + absl::StatusOr<Value> Invoke( |
| 691 | + absl::Span<const Value> args, |
| 692 | + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, |
| 693 | + google::protobuf::MessageFactory* absl_nonnull message_factory, |
| 694 | + google::protobuf::Arena* absl_nonnull arena) const override { |
| 695 | + if (args.size() != sizeof...(Args)) { |
| 696 | + return absl::InvalidArgumentError( |
| 697 | + absl::StrCat("unexpected number of arguments for ", sizeof...(Args), |
| 698 | + "-ary function")); |
| 699 | + } |
| 700 | + ArgBuffer arg_buffer; |
| 701 | + CEL_RETURN_IF_ERROR( |
| 702 | + runtime_internal::AdaptHelper<Args...>::Apply(args, arg_buffer)); |
| 703 | + if constexpr (std::is_same_v<T, Value> || |
| 704 | + std::is_same_v<T, absl::StatusOr<Value>>) { |
| 705 | + return runtime_internal::ToArgsHelper<Args...>::template Apply<T>( |
| 706 | + fn_, arg_buffer, descriptor_pool, message_factory, arena); |
| 707 | + } else { |
| 708 | + T result = runtime_internal::ToArgsHelper<Args...>::template Apply<T>( |
| 709 | + fn_, arg_buffer, descriptor_pool, message_factory, arena); |
| 710 | + return runtime_internal::AdaptedToHandleVisitor{}(std::move(result)); |
| 711 | + } |
| 712 | + } |
| 713 | + |
| 714 | + private: |
| 715 | + FunctionType fn_; |
| 716 | + }; |
| 717 | +}; |
| 718 | + |
632 | 719 | } // namespace cel |
633 | 720 |
|
634 | 721 | #endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ |
0 commit comments