Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ class FlatExprVisitor : public cel::AstVisitor {
resolved_select_expr_(nullptr),
options_(options),
program_optimizers_(std::move(program_optimizers)),
reference_map_(reference_map),
issue_collector_(issue_collector),
program_builder_(program_builder),
extension_context_(extension_context),
Expand Down Expand Up @@ -1670,6 +1671,15 @@ class FlatExprVisitor : public cel::AstVisitor {
suppressed_branches_.insert(expr);
}

const cel::Reference& FindReference(const cel::Expr* expr) const {
auto it = reference_map_.find(expr->id());
if (it == reference_map_.end()) {
static const cel::Reference no_reference;
return no_reference;
}
return it->second;
}

void AddResolvedFunctionStep(const cel::CallExpr* call_expr,
const cel::Expr* expr,
absl::string_view function) {
Expand All @@ -1687,12 +1697,14 @@ class FlatExprVisitor : public cel::AstVisitor {
auto args = program_builder_.current()->ExtractRecursiveDependencies();
SetRecursiveStep(CreateDirectLazyFunctionStep(
expr->id(), *call_expr, std::move(args),
std::move(lazy_overloads)),
std::move(lazy_overloads),
FindReference(expr).overload_id()),
*depth + 1);
return;
}
AddStep(CreateFunctionStep(*call_expr, expr->id(),
std::move(lazy_overloads)));
std::move(lazy_overloads),
FindReference(expr).overload_id()));
return;
}

Expand Down Expand Up @@ -1721,11 +1733,14 @@ class FlatExprVisitor : public cel::AstVisitor {
auto args = program_builder_.current()->ExtractRecursiveDependencies();
SetRecursiveStep(
CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args),
std::move(overloads)),
std::move(overloads),
FindReference(expr).overload_id()),
*recursion_depth + 1);
return;
}
AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads)));
AddStep(CreateFunctionStep(*call_expr, expr->id(),
std::move(overloads),
FindReference(expr).overload_id()));
}

// Add a step to the program, taking ownership. If successful, returns the
Expand Down Expand Up @@ -1963,6 +1978,7 @@ class FlatExprVisitor : public cel::AstVisitor {
absl::flat_hash_set<const cel::Expr*> suppressed_branches_;
const cel::Expr* resume_from_suppressed_branch_ = nullptr;
std::vector<std::unique_ptr<ProgramOptimizer>> program_optimizers_;
const absl::flat_hash_map<int64_t, cel::Reference>& reference_map_;
IssueCollector& issue_collector_;

ProgramBuilder& program_builder_;
Expand Down
3 changes: 2 additions & 1 deletion eval/compiler/flat_expr_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2397,7 +2397,8 @@ class UnknownFunctionImpl : public cel::Function {
absl::StatusOr<Value> Invoke(absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull,
google::protobuf::MessageFactory* absl_nonnull,
google::protobuf::Arena* absl_nonnull) const override {
google::protobuf::Arena* absl_nonnull,
absl::Span<const std::string> overload_id) const override {
return cel::UnknownValue();
}
};
Expand Down
56 changes: 37 additions & 19 deletions eval/eval/function_step.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ class AbstractFunctionStep : public ExpressionStepBase {
public:
// Constructs FunctionStep that uses overloads specified.
AbstractFunctionStep(const std::string& name, size_t num_arguments,
int64_t expr_id)
int64_t expr_id, std::vector<std::string>&& overload_id)
: ExpressionStepBase(expr_id),
name_(name),
num_arguments_(num_arguments) {}
num_arguments_(num_arguments),
overload_id_(std::move(overload_id)) {}

absl::Status Evaluate(ExecutionFrame* frame) const override;

Expand All @@ -172,15 +173,18 @@ class AbstractFunctionStep : public ExpressionStepBase {
protected:
std::string name_;
size_t num_arguments_;
std::vector<std::string> overload_id_;
};

inline absl::StatusOr<Value> Invoke(
const cel::FunctionOverloadReference& overload, int64_t expr_id,
absl::Span<const cel::Value> args, ExecutionFrameBase& frame) {
absl::Span<const cel::Value> args, ExecutionFrameBase& frame,
absl::Span<const std::string> overload_id) {
CEL_ASSIGN_OR_RETURN(
Value result,
overload.implementation.Invoke(args, frame.descriptor_pool(),
frame.message_factory(), frame.arena()));
frame.message_factory(), frame.arena(),
overload_id));

if (frame.unknown_function_results_enabled() &&
IsUnknownFunctionResultError(result)) {
Expand Down Expand Up @@ -240,7 +244,7 @@ absl::StatusOr<Value> AbstractFunctionStep::DoEvaluate(
// Overload found and is allowed to consume the arguments.
if (matched_function.has_value() &&
ShouldAcceptOverload(matched_function->descriptor, input_args)) {
return Invoke(*matched_function, id(), input_args, *frame);
return Invoke(*matched_function, id(), input_args, *frame, overload_id_);
}

return NoOverloadResult(name_, input_args, *frame);
Expand Down Expand Up @@ -323,8 +327,9 @@ absl::StatusOr<ResolveResult> ResolveLazy(
class EagerFunctionStep : public AbstractFunctionStep {
public:
EagerFunctionStep(std::vector<cel::FunctionOverloadReference> overloads,
const std::string& name, size_t num_args, int64_t expr_id)
: AbstractFunctionStep(name, num_args, expr_id),
const std::string& name, size_t num_args, int64_t expr_id,
std::vector<std::string>&& overload_id)
: AbstractFunctionStep(name, num_args, expr_id, std::move(overload_id)),
overloads_(std::move(overloads)) {}

absl::StatusOr<ResolveResult> ResolveFunction(
Expand All @@ -344,8 +349,9 @@ class LazyFunctionStep : public AbstractFunctionStep {
LazyFunctionStep(const std::string& name, size_t num_args,
bool receiver_style,
std::vector<cel::FunctionRegistry::LazyOverload> providers,
int64_t expr_id)
: AbstractFunctionStep(name, num_args, expr_id),
int64_t expr_id,
std::vector<std::string>&& overload_id)
: AbstractFunctionStep(name, num_args, expr_id, std::move(overload_id)),
receiver_style_(receiver_style),
providers_(std::move(providers)) {}

Expand Down Expand Up @@ -404,10 +410,12 @@ class DirectFunctionStepImpl : public DirectExpressionStep {
DirectFunctionStepImpl(
int64_t expr_id, const std::string& name,
std::vector<std::unique_ptr<DirectExpressionStep>> arg_steps,
Resolver&& resolver)
Resolver&& resolver,
std::vector<std::string>&& overload_id)
: DirectExpressionStep(expr_id),
name_(name),
arg_steps_(std::move(arg_steps)),
overload_id_(std::move(overload_id)),
resolver_(std::forward<Resolver>(resolver)) {}

absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result,
Expand Down Expand Up @@ -439,7 +447,8 @@ class DirectFunctionStepImpl : public DirectExpressionStep {
if (resolved_function.has_value() &&
ShouldAcceptOverload(resolved_function->descriptor, args)) {
CEL_ASSIGN_OR_RETURN(result,
Invoke(*resolved_function, expr_id_, args, frame));
Invoke(*resolved_function, expr_id_, args, frame,
overload_id_));

return absl::OkStatus();
}
Expand Down Expand Up @@ -468,6 +477,7 @@ class DirectFunctionStepImpl : public DirectExpressionStep {
friend Resolver;
std::string name_;
std::vector<std::unique_ptr<DirectExpressionStep>> arg_steps_;
std::vector<std::string> overload_id_;
Resolver resolver_;
};

Expand All @@ -476,39 +486,47 @@ class DirectFunctionStepImpl : public DirectExpressionStep {
std::unique_ptr<DirectExpressionStep> CreateDirectFunctionStep(
int64_t expr_id, const cel::CallExpr& call,
std::vector<std::unique_ptr<DirectExpressionStep>> deps,
std::vector<cel::FunctionOverloadReference> overloads) {
std::vector<cel::FunctionOverloadReference> overloads,
std::vector<std::string> overload_id) {
return std::make_unique<DirectFunctionStepImpl<StaticResolver>>(
expr_id, call.function(), std::move(deps),
StaticResolver(std::move(overloads)));
StaticResolver(std::move(overloads)),
std::move(overload_id));
}

std::unique_ptr<DirectExpressionStep> CreateDirectLazyFunctionStep(
int64_t expr_id, const cel::CallExpr& call,
std::vector<std::unique_ptr<DirectExpressionStep>> deps,
std::vector<cel::FunctionRegistry::LazyOverload> providers) {
std::vector<cel::FunctionRegistry::LazyOverload> providers,
std::vector<std::string> overload_id) {
return std::make_unique<DirectFunctionStepImpl<LazyResolver>>(
expr_id, call.function(), std::move(deps),
LazyResolver(std::move(providers), call.function(), call.has_target()));
LazyResolver(std::move(providers), call.function(), call.has_target()),
std::move(overload_id));
}

absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateFunctionStep(
const cel::CallExpr& call_expr, int64_t expr_id,
std::vector<cel::FunctionRegistry::LazyOverload> lazy_overloads) {
std::vector<cel::FunctionRegistry::LazyOverload> lazy_overloads,
std::vector<std::string> overload_id) {
bool receiver_style = call_expr.has_target();
size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0);
const std::string& name = call_expr.function();
return std::make_unique<LazyFunctionStep>(name, num_args, receiver_style,
std::move(lazy_overloads), expr_id);
std::move(lazy_overloads), expr_id,
std::move(overload_id));
}

absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateFunctionStep(
const cel::CallExpr& call_expr, int64_t expr_id,
std::vector<cel::FunctionOverloadReference> overloads) {
std::vector<cel::FunctionOverloadReference> overloads,
std::vector<std::string> overload_id) {
bool receiver_style = call_expr.has_target();
size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0);
const std::string& name = call_expr.function();
return std::make_unique<EagerFunctionStep>(std::move(overloads), name,
num_args, expr_id);
num_args, expr_id,
std::move(overload_id));
}

} // namespace google::api::expr::runtime
12 changes: 8 additions & 4 deletions eval/eval/function_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,32 @@ namespace google::api::expr::runtime {
std::unique_ptr<DirectExpressionStep> CreateDirectFunctionStep(
int64_t expr_id, const cel::CallExpr& call,
std::vector<std::unique_ptr<DirectExpressionStep>> deps,
std::vector<cel::FunctionOverloadReference> overloads);
std::vector<cel::FunctionOverloadReference> overloads,
std::vector<std::string> overload_id = {});

// Factory method for Call-based execution step where the function has been
// statically resolved from a set of lazy functions configured in the
// CelFunctionRegistry.
std::unique_ptr<DirectExpressionStep> CreateDirectLazyFunctionStep(
int64_t expr_id, const cel::CallExpr& call,
std::vector<std::unique_ptr<DirectExpressionStep>> deps,
std::vector<cel::FunctionRegistry::LazyOverload> providers);
std::vector<cel::FunctionRegistry::LazyOverload> providers,
std::vector<std::string> overload_id = {});

// Factory method for Call-based execution step where the function will be
// resolved at runtime (lazily) from an input Activation.
absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateFunctionStep(
const cel::CallExpr& call, int64_t expr_id,
std::vector<cel::FunctionRegistry::LazyOverload> lazy_overloads);
std::vector<cel::FunctionRegistry::LazyOverload> lazy_overloads,
std::vector<std::string> overload_id = {});

// Factory method for Call-based execution step where the function has been
// statically resolved from a set of eagerly functions configured in the
// CelFunctionRegistry.
absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateFunctionStep(
const cel::CallExpr& call, int64_t expr_id,
std::vector<cel::FunctionOverloadReference> overloads);
std::vector<cel::FunctionOverloadReference> overloads,
std::vector<std::string> overload_id = {});

} // namespace google::api::expr::runtime

Expand Down
3 changes: 2 additions & 1 deletion eval/public/cel_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ absl::StatusOr<Value> CelFunction::Invoke(
absl::Span<const cel::Value> arguments,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const {
std::vector<CelValue> legacy_args;
legacy_args.reserve(arguments.size());

Expand Down
3 changes: 2 additions & 1 deletion eval/public/cel_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class CelFunction : public ::cel::Function {
absl::Span<const cel::Value> arguments,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override;
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override;

// CelFunction descriptor
const CelFunctionDescriptor& descriptor() const { return descriptor_; }
Expand Down
3 changes: 2 additions & 1 deletion runtime/activation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class FunctionImpl : public cel::Function {
absl::StatusOr<Value> Invoke(absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull,
google::protobuf::MessageFactory* absl_nonnull,
google::protobuf::Arena* absl_nonnull) const override {
google::protobuf::Arena* absl_nonnull,
absl::Span<const std::string> overload_id) const override {
return NullValue();
}
};
Expand Down
3 changes: 2 additions & 1 deletion runtime/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class Function {
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const = 0;
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id = {}) const = 0;
};

} // namespace cel
Expand Down
15 changes: 10 additions & 5 deletions runtime/function_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ class NullaryFunctionAdapter
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override {
if (args.size() != 0) {
return absl::InvalidArgumentError(
"unexpected number of arguments for nullary function");
Expand Down Expand Up @@ -316,7 +317,8 @@ class UnaryFunctionAdapter : public RegisterHelper<UnaryFunctionAdapter<T, U>> {
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override {
using ArgTraits = runtime_internal::AdaptedTypeTraits<U>;
if (args.size() != 1) {
return absl::InvalidArgumentError(
Expand Down Expand Up @@ -456,7 +458,8 @@ class BinaryFunctionAdapter
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override {
using Arg1Traits = runtime_internal::AdaptedTypeTraits<U>;
using Arg2Traits = runtime_internal::AdaptedTypeTraits<V>;
if (args.size() != 2) {
Expand Down Expand Up @@ -537,7 +540,8 @@ class TernaryFunctionAdapter
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override {
using Arg1Traits = runtime_internal::AdaptedTypeTraits<U>;
using Arg2Traits = runtime_internal::AdaptedTypeTraits<V>;
using Arg3Traits = runtime_internal::AdaptedTypeTraits<W>;
Expand Down Expand Up @@ -624,7 +628,8 @@ class QuaternaryFunctionAdapter
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override {
using Arg1Traits = runtime_internal::AdaptedTypeTraits<U>;
using Arg2Traits = runtime_internal::AdaptedTypeTraits<V>;
using Arg3Traits = runtime_internal::AdaptedTypeTraits<W>;
Expand Down
3 changes: 2 additions & 1 deletion runtime/function_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class ConstIntFunction : public cel::Function {
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override {
return IntValue(42);
}
};
Expand Down
3 changes: 2 additions & 1 deletion runtime/optional_types_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ class UnreachableFunction final : public cel::Function {
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
google::protobuf::Arena* absl_nonnull arena,
absl::Span<const std::string> overload_id) const override {
++(*count_);
return ErrorValue{absl::CancelledError()};
}
Expand Down