diff --git a/MODULE.bazel b/MODULE.bazel index c193f5f5b..19fc67613 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -63,7 +63,7 @@ bazel_dep( ) bazel_dep( name = "cel-spec", - version = "0.23.0", + version = "0.24.0", repo_name = "com_google_cel_spec", ) diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 1f8801dfc..97038915a 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -142,10 +142,10 @@ def cel_spec_deps(): url = "https://github.com/bazelbuild/rules_python/releases/download/0.33.2/rules_python-0.33.2.tar.gz", ) - CEL_SPEC_GIT_SHA = "afa18f9bd5a83f5960ca06c1f9faea406ab34ccc" # Dec 2, 2024 + CEL_SPEC_GIT_SHA = "b86370f27c3275e3240a552e10e42b2d658b456e" # Sep 19, 2025 http_archive( name = "com_google_cel_spec", - sha256 = "19b4084ba33cc8da7a640d999e46731efbec585ad2995951dc61a7af24f059cb", + sha256 = "d5558cd419c8d46bdc958064cb97f963d1ea793866414c025906ec15033512ed", strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], ) diff --git a/common/BUILD b/common/BUILD index d800b36be..dd41f145d 100644 --- a/common/BUILD +++ b/common/BUILD @@ -628,10 +628,12 @@ cc_library( ":native_type", ":optional_ref", ":type", + ":typeinfo", ":unknown", ":value_kind", "//base:attributes", "//common/internal:byte_string", + "//common/internal:reference_count", "//eval/internal:cel_value_equal", "//eval/public:cel_value", "//eval/public:message_wrapper", @@ -656,6 +658,7 @@ cc_library( "//internal:utf8", "//internal:well_known_types", "//runtime:runtime_options", + "//runtime/internal:errors", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", diff --git a/common/value.h b/common/value.h index 0e38646a7..d24490a93 100644 --- a/common/value.h +++ b/common/value.h @@ -39,6 +39,7 @@ #include "common/native_type.h" #include "common/optional_ref.h" #include "common/type.h" +#include "common/typeinfo.h" #include "common/value_kind.h" #include "common/values/bool_value.h" // IWYU pragma: export #include "common/values/bytes_value.h" // IWYU pragma: export @@ -2537,6 +2538,75 @@ ErrorValueAssign::operator()(absl::Status status) const { return common_internal::ImplicitlyConvertibleStatus(); } +inline absl::StatusOr StringValue::Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR( + Join(list, descriptor_pool, message_factory, arena, &result)); + return result; +} + +inline absl::StatusOr StringValue::Split( + const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(Split(delimiter, limit, arena, &result)); + return result; +} + +inline absl::Status StringValue::Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return Split(delimiter, /*limit=*/-1, arena, result); +} + +inline absl::StatusOr StringValue::Split( + const StringValue& delimiter, google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + return Split(delimiter, /*limit=*/-1, arena); +} + +inline absl::StatusOr StringValue::Replace( + const StringValue& needle, const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(Replace(needle, replacement, limit, arena, &result)); + return result; +} + +inline absl::Status StringValue::Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return Replace(needle, replacement, /*limit=*/-1, arena, result); +} + +inline absl::StatusOr StringValue::Replace( + const StringValue& needle, const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + return Replace(needle, replacement, /*limit=*/-1, arena); +} + namespace common_internal { template diff --git a/common/values/string_value.cc b/common/values/string_value.cc index ba065d275..95166095b 100644 --- a/common/values/string_value.cc +++ b/common/values/string_value.cc @@ -13,24 +13,36 @@ // limitations under the License. #include +#include #include +#include #include +#include +#include +#include #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/cord.h" +#include "absl/strings/cord_buffer.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "common/internal/byte_string.h" +#include "common/internal/reference_count.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "internal/utf8.h" #include "internal/well_known_types.h" +#include "runtime/internal/errors.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" @@ -219,4 +231,1309 @@ bool StringValue::Contains(const StringValue& string) const { [&](const absl::Cord& rhs) -> bool { return Contains(rhs); })); } +int64_t StringValue::IndexOf(absl::string_view string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> int64_t { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (absl::StartsWith(lhs, string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return -1; + }, + [&](absl::Cord lhs) -> int64_t { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return -1; + })); +} + +int64_t StringValue::IndexOf(const absl::Cord& string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> int64_t { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.substr(0, string.size()) == string) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return -1; + }, + [&](absl::Cord lhs) -> int64_t { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + return code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return -1; + })); +} + +int64_t StringValue::IndexOf(const StringValue& string) const { + return string.value_.Visit(absl::Overload( + [this](absl::string_view rhs) -> int64_t { return IndexOf(rhs); }, + [this](const absl::Cord& rhs) -> int64_t { return IndexOf(rhs); })); +} + +Value StringValue::IndexOf(absl::string_view string, int64_t pos) const { + if (pos < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is less than 0")); + } + if (static_cast(pos) > value_.size()) { + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is greater than or equal to " + ".size()")); + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> Value { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && absl::StartsWith(lhs, string)) { + return IntValue(code_points); + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(-1); + } + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is greater than or equal to " + ".size()")); + }, + [&](absl::Cord lhs) -> Value { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && lhs.StartsWith(string)) { + return IntValue(code_points); + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(-1); + } + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is greater than or equal to " + ".size()")); + })); +} + +Value StringValue::IndexOf(const absl::Cord& string, int64_t pos) const { + if (pos < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is less than 0")); + } + if (static_cast(pos) > value_.size()) { + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is greater than or equal to " + ".size()")); + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> Value { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && lhs.substr(0, string.size()) == string) { + return IntValue(code_points); + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(-1); + } + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is greater than or equal to " + ".size()")); + }, + [&](absl::Cord lhs) -> Value { + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (code_points >= pos && lhs.StartsWith(string)) { + return IntValue(code_points); + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(-1); + } + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is greater than or equal to " + ".size()")); + })); +} + +Value StringValue::IndexOf(const StringValue& string, int64_t pos) const { + return string.value_.Visit(absl::Overload( + [this, pos](absl::string_view rhs) -> Value { return IndexOf(rhs, pos); }, + [this, pos](const absl::Cord& rhs) -> Value { + return IndexOf(rhs, pos); + })); +} + +int64_t StringValue::LastIndexOf(absl::string_view string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> int64_t { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (absl::StartsWith(lhs, string)) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return last_index; + }, + [&](absl::Cord lhs) -> int64_t { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return last_index; + })); +} + +int64_t StringValue::LastIndexOf(const absl::Cord& string) const { + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> int64_t { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.substr(0, string.size()) == string) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + return last_index; + }, + [&](absl::Cord lhs) -> int64_t { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + return last_index; + })); +} + +int64_t StringValue::LastIndexOf(const StringValue& string) const { + return string.value_.Visit(absl::Overload( + [this](absl::string_view rhs) -> int64_t { return LastIndexOf(rhs); }, + [this](const absl::Cord& rhs) -> int64_t { return LastIndexOf(rhs); })); +} + +Value StringValue::LastIndexOf(absl::string_view string, int64_t pos) const { + if (pos < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".indexOf(, ): is less than 0")); + } + if (static_cast(pos) > value_.size()) { + return ErrorValue( + absl::InvalidArgumentError(".lastIndexOf(, ): " + " is greater than or equal to " + ".size()")); + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> Value { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (absl::StartsWith(lhs, string)) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(last_index); + } + return ErrorValue( + absl::InvalidArgumentError(".lastIndexOf(, ): " + " is greater than or equal to " + ".size()")); + }, + [&](absl::Cord lhs) -> Value { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(last_index); + } + return ErrorValue( + absl::InvalidArgumentError(".lastIndexOf(, ): " + " is greater than or equal to " + ".size()")); + })); +} + +Value StringValue::LastIndexOf(const absl::Cord& string, int64_t pos) const { + if (pos < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".lastIndexOf(, ): is less than 0")); + } + if (static_cast(pos) > value_.size()) { + return ErrorValue( + absl::InvalidArgumentError(".lastIndexOf(, ): " + " is greater than or equal to " + ".size()")); + } + return value_.Visit(absl::Overload( + [&](absl::string_view lhs) -> Value { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.substr(0, string.size()) == string) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = + cel::internal::Utf8Decode(lhs, /*code_point=*/nullptr); + lhs.remove_prefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(last_index); + } + return ErrorValue( + absl::InvalidArgumentError(".lastIndexOf(, ): " + " is greater than or equal to " + ".size()")); + }, + [&](absl::Cord lhs) -> Value { + int64_t last_index = -1; + int64_t code_points = 0; + while (lhs.size() >= string.size()) { + if (lhs.StartsWith(string)) { + last_index = code_points; + } + if (code_points >= pos || lhs.size() == string.size()) { + break; + } + size_t code_units = cel::internal::Utf8Decode(lhs.char_begin(), + /*code_point=*/nullptr); + lhs.RemovePrefix(code_units); + ++code_points; + } + if (code_points >= pos) { + return IntValue(last_index); + } + return ErrorValue( + absl::InvalidArgumentError(".lastIndexOf(, ): " + " is greater than or equal to " + ".size()")); + })); +} + +Value StringValue::LastIndexOf(const StringValue& string, int64_t pos) const { + return string.value_.Visit(absl::Overload( + [this, pos](absl::string_view rhs) -> Value { + return LastIndexOf(rhs, pos); + }, + [this, pos](const absl::Cord& rhs) -> Value { + return LastIndexOf(rhs, pos); + })); +} + +namespace { + +absl::StatusOr SubstringImpl(absl::string_view string, uint64_t start) { + size_t size_code_points = 0; + size_t size_code_units = 0; + while (!string.empty()) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(string); + if (size_code_points == start) { + return size_code_units; + } + string.remove_prefix(code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start) { + return size_code_units; + } + return absl::InvalidArgumentError( + ".substring(): is greater than .size()"); +} + +absl::StatusOr SubstringImpl(const absl::Cord& cord, + uint64_t start) { + absl::Cord::CharIterator char_begin = cord.char_begin(); + absl::Cord::CharIterator char_end = cord.char_end(); + size_t size_code_points = 0; + size_t size_code_units = 0; + while (char_begin != char_end) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(char_begin); + if (size_code_points == start) { + return cord.Subcord(size_code_units, std::numeric_limits::max()); + } + absl::Cord::Advance(&char_begin, code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start) { + return cord; + } + return absl::InvalidArgumentError( + ".substring(): is greater than .size()"); +} + +} // namespace + +Value StringValue::Substring(int64_t start) const { + if (start < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(): is less than 0")); + } + if (static_cast(start) > value_.size()) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()")); + } + if (start == 0) { + return *this; + } + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + absl::StatusOr status_or_index = + (SubstringImpl)(value_.GetSmall(), start); + if (!status_or_index.ok()) { + return ErrorValue(std::move(status_or_index).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = value_.rep_.small.size - *status_or_index; + std::memcpy(result.value_.rep_.small.data, + value_.rep_.small.data + *status_or_index, + result.value_.rep_.small.size); + result.value_.rep_.small.arena = value_.rep_.small.arena; + return result; + } + case common_internal::ByteStringKind::kMedium: { + absl::StatusOr status_or_index = + (SubstringImpl)(value_.GetMedium(), start); + if (!status_or_index.ok()) { + return ErrorValue(std::move(status_or_index).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; + result.value_.rep_.medium.size = + value_.rep_.medium.size - *status_or_index; + result.value_.rep_.medium.data = + value_.rep_.medium.data + *status_or_index; + result.value_.rep_.medium.owner = value_.rep_.medium.owner; + common_internal::StrongRef(result.value_.GetMediumReferenceCount()); + return result; + } + case common_internal::ByteStringKind::kLarge: { + absl::StatusOr status_or_cord = + (SubstringImpl)(value_.GetLarge(), start); + if (!status_or_cord.ok()) { + return ErrorValue(std::move(status_or_cord).status()); + } + return StringValue::Wrap(*std::move(status_or_cord)); + } + } +} + +namespace { + +absl::StatusOr> SubstringImpl( + absl::string_view string, uint64_t start, uint64_t end) { + size_t size_code_points = 0; + size_t size_code_units = 0; + size_t start_code_units; + while (!string.empty()) { + if (size_code_points == start) { + start_code_units = size_code_units; + } + if (size_code_points == end) { + return std::pair{start_code_units, size_code_units}; + } + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(string); + string.remove_prefix(code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start && start == end) { + return std::pair{size_code_units, size_code_units}; + } + return absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()"); +} + +absl::StatusOr SubstringImpl(const absl::Cord& cord, uint64_t start, + uint64_t end) { + absl::Cord::CharIterator char_begin = cord.char_begin(); + absl::Cord::CharIterator char_end = cord.char_end(); + size_t size_code_points = 0; + size_t size_code_units = 0; + size_t start_code_units; + while (char_begin != char_end) { + if (size_code_points == start) { + start_code_units = size_code_units; + } + if (size_code_points == end) { + return cord.Subcord(start_code_units, + size_code_points - start_code_units); + } + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(char_begin); + absl::Cord::Advance(&char_begin, code_units); + ++size_code_points; + size_code_units += code_units; + } + if (size_code_points == start && start == end) { + return absl::Cord(); + } + return absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()"); +} + +} // namespace + +Value StringValue::Substring(int64_t start, int64_t end) const { + if (start < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): is less than 0")); + } + if (end < start) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): is less than ")); + } + if (static_cast(start) > value_.size() || + static_cast(end) > value_.size()) { + return ErrorValue(absl::InvalidArgumentError( + ".substring(, ): or is greater than " + ".size()")); + } + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + absl::StatusOr> status_or_indices = + (SubstringImpl)(value_.GetSmall(), start, end); + if (!status_or_indices.ok()) { + return ErrorValue(std::move(status_or_indices).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = + (status_or_indices->second - status_or_indices->first); + std::memcpy(result.value_.rep_.small.data, + value_.rep_.small.data + status_or_indices->first, + result.value_.rep_.small.size); + result.value_.rep_.small.arena = value_.rep_.small.arena; + return result; + } + case common_internal::ByteStringKind::kMedium: { + absl::StatusOr> status_or_indices = + (SubstringImpl)(value_.GetMedium(), start, end); + if (!status_or_indices.ok()) { + return ErrorValue(std::move(status_or_indices).status()); + } + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; + result.value_.rep_.medium.size = + (status_or_indices->second - status_or_indices->first); + result.value_.rep_.medium.data = + value_.rep_.medium.data + status_or_indices->first; + result.value_.rep_.medium.owner = value_.rep_.medium.owner; + common_internal::StrongRef(result.value_.GetMediumReferenceCount()); + return result; + } + case common_internal::ByteStringKind::kLarge: { + absl::StatusOr status_or_cord = + (SubstringImpl)(value_.GetLarge(), start, end); + if (!status_or_cord.ok()) { + return ErrorValue(std::move(status_or_cord).status()); + } + return StringValue::Wrap(*std::move(status_or_cord)); + } + } +} + +namespace { + +bool LowerAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { + if (in.empty()) { + return false; + } + size_t pos; + for (pos = 0; pos < in.size(); ++pos) { + if (absl::ascii_isupper(in[pos])) { + break; + } + } + if (pos == in.size()) { + return false; + } + out->resize(in.size()); + char* out_data = out->data(); + if (pos > 0) { + std::memcpy(out_data, in.data(), pos); + } + for (size_t i = pos; i < in.size(); ++i) { + out_data[i] = absl::ascii_tolower(in[i]); + } + return true; +} + +absl::Cord LowerAsciiImpl(const absl::Cord& in) { + if (in.empty()) { + return in; + } + size_t pos; + absl::Cord::CharIterator begin = in.char_begin(); + absl::Cord::CharIterator end = in.char_end(); + for (pos = 0; begin != end; ++pos, ++begin) { + if (absl::ascii_isupper(*begin)) { + break; + } + } + if (begin == end) { + return in; + } + absl::Cord out = in.Subcord(0, pos); + size_t n = in.size() - pos; + bool first = true; + while (begin != end) { + absl::CordBuffer buffer = first + ? out.GetAppendBuffer(n) + : absl::CordBuffer::CreateWithDefaultLimit(n); + absl::Span data = buffer.available_up_to(n); + size_t i; + for (i = 0; i < data.size() && begin != end; ++i, ++begin) { + data[i] = absl::ascii_tolower(*begin); + } + buffer.IncreaseLengthBy(i); + out.Append(std::move(buffer)); + n -= i; + first = false; + } + return out; +} + +} // namespace + +StringValue StringValue::LowerAscii(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::string out; + if (!(LowerAsciiImpl)(value_.GetSmall(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kMedium: { + std::string out; + if (!(LowerAsciiImpl)(value_.GetMedium(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kLarge: + return StringValue::Wrap((LowerAsciiImpl)(value_.GetLarge())); + } +} + +namespace { + +bool UpperAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { + if (in.empty()) { + return false; + } + size_t pos; + for (pos = 0; pos < in.size(); ++pos) { + if (absl::ascii_islower(in[pos])) { + break; + } + } + if (pos == in.size()) { + return false; + } + out->resize(in.size()); + char* out_data = out->data(); + for (size_t i = 0; i < in.size(); ++i) { + out_data[i] = absl::ascii_toupper(in[i]); + } + return true; +} + +absl::Cord UpperAsciiImpl(const absl::Cord& in) { + if (in.empty()) { + return in; + } + size_t pos; + absl::Cord::CharIterator begin = in.char_begin(); + absl::Cord::CharIterator end = in.char_end(); + for (pos = 0; begin != end; ++pos, ++begin) { + if (absl::ascii_islower(*begin)) { + break; + } + } + if (begin == end) { + return in; + } + absl::Cord out = in.Subcord(0, pos); + size_t n = in.size() - pos; + bool first = true; + while (begin != end) { + absl::CordBuffer buffer = first + ? out.GetAppendBuffer(n) + : absl::CordBuffer::CreateWithDefaultLimit(n); + absl::Span data = buffer.available_up_to(n); + size_t i; + for (i = 0; i < data.size() && begin != end; ++i, ++begin) { + data[i] = absl::ascii_toupper(*begin); + } + buffer.IncreaseLengthBy(i); + out.Append(std::move(buffer)); + n -= i; + first = false; + } + return out; +} + +} // namespace + +StringValue StringValue::UpperAscii(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::string out; + if (!(UpperAsciiImpl)(value_.GetSmall(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kMedium: { + std::string out; + if (!(UpperAsciiImpl)(value_.GetMedium(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kLarge: + return StringValue::Wrap((UpperAsciiImpl)(value_.GetLarge())); + } +} + +namespace { + +// Per CEL spec, checking for Unicode whitespace. +bool IsUnicodeWhitespace(char32_t c) { + if (c <= 0x0020) { + return c == 0x0020 || (c >= 0x0009 && c <= 0x000D); + } + if (c > 0x3000) return false; + if (c == 0x0085 || c == 0x00a0 || c == 0x1680) return true; + if (c >= 0x2000 && c <= 0x200a) return true; + return c == 0x2028 || c == 0x2029 || c == 0x202f || c == 0x205f || + c == 0x3000; +} + +std::pair TrimImpl(absl::string_view string) { + absl::string_view temp_string = string; + size_t left_trim_bytes = 0; + while (!temp_string.empty()) { + char32_t c; + size_t char_len = cel::internal::Utf8Decode(temp_string, &c); + if (!IsUnicodeWhitespace(c)) { + break; + } + temp_string.remove_prefix(char_len); + left_trim_bytes += char_len; + } + + if (left_trim_bytes == string.size()) { + return {left_trim_bytes, 0}; + } + + size_t last_non_ws_end_bytes = 0; + size_t current_pos_bytes = 0; + temp_string = string; + while (!temp_string.empty()) { + char32_t c; + size_t char_len = cel::internal::Utf8Decode(temp_string, &c); + if (!IsUnicodeWhitespace(c)) { + last_non_ws_end_bytes = current_pos_bytes + char_len; + } + current_pos_bytes += char_len; + temp_string.remove_prefix(char_len); + } + + return {left_trim_bytes, string.size() - last_non_ws_end_bytes}; +} + +absl::Cord TrimImpl(const absl::Cord& cord) { + size_t left_trim_bytes = 0; + { + absl::Cord::CharIterator begin = cord.char_begin(); + const absl::Cord::CharIterator end = cord.char_end(); + while (begin != end) { + char32_t c; + size_t char_len; + std::tie(c, char_len) = cel::internal::Utf8Decode(begin); + if (!IsUnicodeWhitespace(c)) { + break; + } + absl::Cord::Advance(&begin, char_len); + left_trim_bytes += char_len; + } + } + + if (left_trim_bytes == cord.size()) { + return absl::Cord(); + } + + absl::Cord ltrimmed = + cord.Subcord(left_trim_bytes, cord.size() - left_trim_bytes); + + size_t last_non_ws_end_bytes = 0; + size_t current_pos_bytes = 0; + { + absl::Cord::CharIterator begin = ltrimmed.char_begin(); + const absl::Cord::CharIterator end = ltrimmed.char_end(); + while (begin != end) { + char32_t c; + size_t char_len; + std::tie(c, char_len) = cel::internal::Utf8Decode(begin); + if (!IsUnicodeWhitespace(c)) { + last_non_ws_end_bytes = current_pos_bytes + char_len; + } + absl::Cord::Advance(&begin, char_len); + current_pos_bytes += char_len; + } + } + return ltrimmed.Subcord(0, last_non_ws_end_bytes); +} + +} // namespace + +StringValue StringValue::Trim() const { + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::pair trims = (TrimImpl)(value_.GetSmall()); + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = + value_.rep_.small.size - trims.first - trims.second; + std::memcpy(result.value_.rep_.small.data, + value_.rep_.small.data + trims.first, + result.value_.rep_.small.size); + result.value_.rep_.small.arena = value_.GetSmallArena(); + return result; + } + case common_internal::ByteStringKind::kMedium: { + std::pair trims = (TrimImpl)(value_.GetMedium()); + StringValue result; + result.value_.rep_.header.kind = common_internal::ByteStringKind::kMedium; + result.value_.rep_.medium.size = + value_.rep_.medium.size - trims.first - trims.second; + result.value_.rep_.medium.data = value_.rep_.medium.data + trims.first; + result.value_.rep_.medium.owner = value_.rep_.medium.owner; + common_internal::StrongRef(result.value_.GetMediumReferenceCount()); + return result; + } + case common_internal::ByteStringKind::kLarge: { + return StringValue::Wrap((TrimImpl)(value_.GetLarge())); + } + } +} + +StringValue StringValue::Quote(google::protobuf::Arena* absl_nonnull arena) const { + return value_.Visit(absl::Overload( + [&](absl::string_view rhs) -> StringValue { + std::string result; + result.push_back('\"'); + while (!rhs.empty()) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(rhs); + switch (code_point) { + case '\a': + result.append("\\a"); + break; + case '\b': + result.append("\\b"); + break; + case '\f': + result.append("\\f"); + break; + case '\n': + result.append("\\n"); + break; + case '\r': + result.append("\\r"); + break; + case '\t': + result.append("\\t"); + break; + case '\v': + result.append("\\v"); + break; + case '\\': + result.append("\\\\"); + break; + case '\"': + result.append("\\\""); + break; + default: + cel::internal::Utf8Encode(code_point, &result); + break; + } + rhs.remove_prefix(code_units); + } + result.push_back('\"'); + return StringValue::From(std::move(result), arena); + }, + [&](const absl::Cord& rhs) -> StringValue { + absl::Cord::CharIterator begin = rhs.char_begin(); + absl::Cord::CharIterator end = rhs.char_end(); + std::string result; + result.push_back('\"'); + while (begin != end) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(begin); + switch (code_point) { + case '\a': + result.append("\\a"); + break; + case '\b': + result.append("\\b"); + break; + case '\f': + result.append("\\f"); + break; + case '\n': + result.append("\\n"); + break; + case '\r': + result.append("\\r"); + break; + case '\t': + result.append("\\t"); + break; + case '\v': + result.append("\\v"); + break; + case '\\': + result.append("\\\\"); + break; + case '\"': + result.append("\\\""); + break; + default: + cel::internal::Utf8Encode(code_point, &result); + break; + } + absl::Cord::Advance(&begin, code_units); + } + result.push_back('\"'); + return StringValue::From(std::move(result), arena); + })); +} + +StringValue StringValue::Reverse(google::protobuf::Arena* absl_nonnull arena) const { + return value_.Visit(absl::Overload( + [arena](absl::string_view string) -> StringValue { + if (string.empty()) { + return StringValue(); + } + std::string reversed; + reversed.reserve(string.size()); + const char* ptr = string.data() + string.size(); + const char* begin = string.data(); + while (ptr > begin) { + const char* char_end = ptr; + --ptr; + while (ptr > begin && (*ptr & 0xC0) == 0x80) { + --ptr; + } + reversed.append(ptr, char_end - ptr); + } + return StringValue::From(std::move(reversed), arena); + }, + [&](const absl::Cord& cord) -> StringValue { + if (cord.empty()) { + return StringValue(); + } + std::vector code_points; + absl::Cord::CharIterator char_begin = cord.char_begin(); + absl::Cord::CharIterator char_end = cord.char_end(); + size_t current_pos = 0; + while (char_begin != char_end) { + size_t code_units = + cel::internal::Utf8Decode(char_begin, /*code_point=*/nullptr); + code_points.push_back(cord.Subcord(current_pos, code_units)); + absl::Cord::Advance(&char_begin, code_units); + current_pos += code_units; + } + absl::Cord result; + for (auto it = code_points.rbegin(); it != code_points.rend(); ++it) { + result.Append(*it); + } + return StringValue::Wrap(std::move(result)); + })); +} + +absl::Status StringValue::Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + std::string joined; + + CEL_ASSIGN_OR_RETURN(auto iterator, list.NewIterator()); + + CEL_ASSIGN_OR_RETURN( + absl::optional element, + iterator->Next1(descriptor_pool, message_factory, arena)); + if (element) { + if (auto string_element = element->AsString(); string_element) { + string_element->AppendToString(&joined); + } else { + ABSL_DCHECK(!element->Is()); + *result = + ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join")); + return absl::OkStatus(); + } + while (true) { + CEL_ASSIGN_OR_RETURN( + element, iterator->Next1(descriptor_pool, message_factory, arena)); + if (!element) { + break; + } + AppendToString(&joined); + if (auto string_element = element->AsString(); string_element) { + string_element->AppendToString(&joined); + } else { + ABSL_DCHECK(!element->Is()); + *result = + ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join")); + return absl::OkStatus(); + } + } + } + + if (joined.size() > common_internal::kSmallByteStringCapacity) { + joined.shrink_to_fit(); + } + + *result = StringValue::From(std::move(joined), arena); + return absl::OkStatus(); +} + +absl::Status StringValue::Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (limit == 0) { + // Per spec, when limit is 0 return an empty list. + *result = ListValue(); + return absl::OkStatus(); + } + if (limit < 0) { + // Per spec, when limit is negative treat it as unlimited splits. + limit = std::numeric_limits::max(); + } + + std::vector> splits; + size_t pos = 0; + const size_t len = value_.size(); + + if (delimiter.IsEmpty()) { + while (pos < len && limit > 1) { + size_t char_len = 1; + value_.Visit(absl::Overload( + [&](absl::string_view s) { + char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr); + }, + [&](const absl::Cord& s) { + char_len = cel::internal::Utf8Decode( + s.Subcord(pos, len - pos).char_begin(), nullptr); + })); + splits.push_back({pos, pos + char_len}); + pos += char_len; + --limit; + } + } else { + while (pos < len && limit > 1) { + absl::optional next = value_.Find(delimiter.value_, pos); + if (!next) { + break; + } + splits.push_back(std::pair{pos, *next}); + pos = *next + delimiter.value_.size(); + --limit; + ABSL_DCHECK_LE(pos, len); + } + } + + if (splits.empty() || !delimiter.IsEmpty() || pos < len) { + splits.push_back(std::pair{pos, len}); + } + + auto builder = NewListValueBuilder(arena); + builder->Reserve(splits.size()); + for (const std::pair& split : splits) { + builder->UnsafeAdd( + StringValue(value_.Substring(split.first, split.second))); + } + *result = std::move(*builder).Build(); + return absl::OkStatus(); +} + +absl::Status StringValue::Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (limit == 0) { + // Per spec, when limit is 0 return the original string. + *result = *this; + return absl::OkStatus(); + } + if (limit < 0) { + // Per spec, when limit is negative treat it as unlimited replacements. + limit = std::numeric_limits::max(); + } + + size_t pos = 0; + const size_t len = value_.size(); + const size_t needle_len = needle.value_.size(); + std::string res_str; + + if (needle.IsEmpty()) { + while (pos < len && limit > 0) { + replacement.AppendToString(&res_str); + + size_t char_len = 1; + value_.Visit(absl::Overload( + [&](absl::string_view s) { + char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr); + }, + [&](const absl::Cord& s) { + char_len = cel::internal::Utf8Decode( + s.Subcord(pos, len - pos).char_begin(), nullptr); + })); + value_.Substring(pos, char_len).AppendToString(&res_str); + pos += char_len; + --limit; + } + if (limit > 0) { + replacement.AppendToString(&res_str); + } + } else { + while (pos < len && limit > 0) { + absl::optional next = value_.Find(needle.value_, pos); + if (!next) { + break; + } + + value_.Substring(pos, *next).AppendToString(&res_str); + replacement.AppendToString(&res_str); + + pos = *next + needle_len; + --limit; + } + } + + if (pos < len) { + value_.Substring(pos, len).AppendToString(&res_str); + } + + if (res_str.size() > common_internal::kSmallByteStringCapacity) { + res_str.shrink_to_fit(); + } + + *result = StringValue::From(std::move(res_str), arena); + return absl::OkStatus(); +} + +Value StringValue::CharAt(int64_t pos) const { + if (pos < 0) { + return ErrorValue(absl::InvalidArgumentError( + ".charAt(): is less than 0")); + } + return value_.Visit(absl::Overload( + [this, pos](absl::string_view rhs) mutable -> Value { + size_t size = 0; + while (!rhs.empty()) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(rhs); + if (pos == 0) { + StringValue result; + result.value_.rep_.header.kind = + common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = cel::internal::Utf8Encode( + code_point, result.value_.rep_.small.data); + result.value_.rep_.small.arena = value_.GetArena(); + return result; + } + rhs.remove_prefix(code_units); + --pos; + ++size; + } + if (pos == 0) { + return StringValue(); + } + return ErrorValue(absl::InvalidArgumentError( + ".charAt(): is greater than .size()")); + }, + [pos](const absl::Cord& rhs) mutable -> Value { + absl::Cord::CharIterator begin = rhs.char_begin(); + absl::Cord::CharIterator end = rhs.char_end(); + size_t size = 0; + while (begin != end) { + char32_t code_point; + size_t code_units; + std::tie(code_point, code_units) = cel::internal::Utf8Decode(begin); + if (pos == 0) { + StringValue result; + result.value_.rep_.header.kind = + common_internal::ByteStringKind::kSmall; + result.value_.rep_.small.size = cel::internal::Utf8Encode( + code_point, result.value_.rep_.small.data); + result.value_.rep_.small.arena = nullptr; + return result; + } + absl::Cord::Advance(&begin, code_units); + --pos; + ++size; + } + if (pos == 0) { + return StringValue(); + } + return ErrorValue(absl::InvalidArgumentError( + ".charAt(): is greater than .size()")); + })); +} + } // namespace cel diff --git a/common/values/string_value.h b/common/values/string_value.h index f7dcfc8d1..58b33bc8d 100644 --- a/common/values/string_value.h +++ b/common/values/string_value.h @@ -19,6 +19,7 @@ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRING_VALUE_H_ #include +#include #include #include #include @@ -28,6 +29,7 @@ #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -46,6 +48,7 @@ namespace cel { class Value; +class ListValue; class StringValue; namespace common_internal { @@ -208,6 +211,73 @@ class StringValue final : private common_internal::ValueMixin { bool Contains(const absl::Cord& string) const; bool Contains(const StringValue& string) const; + int64_t IndexOf(absl::string_view string) const; + int64_t IndexOf(const absl::Cord& string) const; + int64_t IndexOf(const StringValue& string) const; + Value IndexOf(absl::string_view string, int64_t pos) const; + Value IndexOf(const absl::Cord& string, int64_t pos) const; + Value IndexOf(const StringValue& string, int64_t pos) const; + + int64_t LastIndexOf(absl::string_view string) const; + int64_t LastIndexOf(const absl::Cord& string) const; + int64_t LastIndexOf(const StringValue& string) const; + Value LastIndexOf(absl::string_view string, int64_t pos) const; + Value LastIndexOf(const absl::Cord& string, int64_t pos) const; + Value LastIndexOf(const StringValue& string, int64_t pos) const; + + Value Substring(int64_t start) const; + + Value Substring(int64_t start, int64_t end) const; + + StringValue LowerAscii(google::protobuf::Arena* absl_nonnull arena) const; + + StringValue UpperAscii(google::protobuf::Arena* absl_nonnull arena) const; + + StringValue Trim() const; + + StringValue Quote(google::protobuf::Arena* absl_nonnull arena) const; + + StringValue Reverse(google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Join(const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const; + absl::Status Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const; + absl::Status Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena) const; + + Value CharAt(int64_t pos) const; + absl::optional TryFlat() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_.TryFlat(); diff --git a/common/values/string_value_test.cc b/common/values/string_value_test.cc index 244fd3f7e..b1f062bae 100644 --- a/common/values/string_value_test.cc +++ b/common/values/string_value_test.cc @@ -16,6 +16,7 @@ #include #include "absl/hash/hash.h" +#include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" @@ -208,5 +209,182 @@ TEST_F(StringValueTest, Contains) { .Contains(StringValue(absl::Cord("string is large enough")))); } +TEST_F(StringValueTest, LowerAscii) { + EXPECT_EQ(StringValue("UPPER lower").LowerAscii(arena()), "upper lower"); + EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).LowerAscii(arena()), + "upper lower"); + EXPECT_EQ(StringValue("upper lower").LowerAscii(arena()), "upper lower"); + EXPECT_EQ(StringValue(absl::Cord("upper lower")).LowerAscii(arena()), + "upper lower"); + EXPECT_EQ(StringValue("").LowerAscii(arena()), ""); + EXPECT_EQ(StringValue(absl::Cord("")).LowerAscii(arena()), ""); + const std::string kLongMixed = + "A long STRING with MiXeD case to test conversion to lower case!"; + const std::string kLongLower = + "a long string with mixed case to test conversion to lower case!"; + EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).LowerAscii(arena()), + kLongLower); + std::string very_long_mixed(10000, 'A'); + std::string very_long_lower(10000, 'a'); + EXPECT_EQ( + StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000), + very_long_mixed.substr(5000)})) + .LowerAscii(arena()), + very_long_lower); +} + +TEST_F(StringValueTest, UpperAscii) { + EXPECT_EQ(StringValue("UPPER lower").UpperAscii(arena()), "UPPER LOWER"); + EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).UpperAscii(arena()), + "UPPER LOWER"); + EXPECT_EQ(StringValue("UPPER LOWER").UpperAscii(arena()), "UPPER LOWER"); + EXPECT_EQ(StringValue(absl::Cord("UPPER LOWER")).UpperAscii(arena()), + "UPPER LOWER"); + EXPECT_EQ(StringValue("").UpperAscii(arena()), ""); + EXPECT_EQ(StringValue(absl::Cord("")).UpperAscii(arena()), ""); + const std::string kLongMixed = + "A long STRING with MiXeD case to test conversion to UPPER case!"; + const std::string kLongUpper = + "A LONG STRING WITH MIXED CASE TO TEST CONVERSION TO UPPER CASE!"; + EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).UpperAscii(arena()), + kLongUpper); + std::string very_long_mixed(10000, 'a'); + std::string very_long_upper(10000, 'A'); + EXPECT_EQ( + StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000), + very_long_mixed.substr(5000)})) + .UpperAscii(arena()), + very_long_upper); +} + +TEST_F(StringValueTest, LastIndexOf) { + using ::cel::test::ErrorValueIs; + using ::cel::test::IntValueIs; + StringValue big_string = + StringValue("This string is large enough to not be stored inline!"); + StringValue big_string_cord = StringValue( + absl::Cord("This string is large enough to not be stored inline!")); + StringValue small_string = StringValue("is"); + StringValue small_string_cord = StringValue(absl::Cord("is")); + + EXPECT_EQ(big_string.LastIndexOf(small_string), 12); + EXPECT_EQ(big_string.LastIndexOf(small_string_cord), 12); + EXPECT_EQ(big_string_cord.LastIndexOf(small_string), 12); + EXPECT_EQ(big_string_cord.LastIndexOf(small_string_cord), 12); + + EXPECT_EQ(big_string.LastIndexOf("is"), 12); + EXPECT_EQ(big_string_cord.LastIndexOf("is"), 12); + + EXPECT_THAT(big_string.LastIndexOf(small_string, 4), IntValueIs(2)); + EXPECT_THAT(big_string.LastIndexOf(small_string_cord, 4), IntValueIs(2)); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string, 4), IntValueIs(2)); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord, 4), IntValueIs(2)); + + EXPECT_THAT(big_string.LastIndexOf("is", 4), IntValueIs(2)); + EXPECT_THAT(big_string_cord.LastIndexOf("is", 4), IntValueIs(2)); + + EXPECT_THAT(big_string.LastIndexOf(small_string, 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); + EXPECT_THAT(big_string.LastIndexOf(small_string_cord, 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string, 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); + EXPECT_THAT(big_string_cord.LastIndexOf(small_string_cord, 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); + EXPECT_THAT(big_string.LastIndexOf(absl::Cord("is"), 4), IntValueIs(2)); + EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord("is"), 4), IntValueIs(2)); + EXPECT_THAT(big_string.LastIndexOf(absl::Cord("is"), 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); + EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord("is"), 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); + EXPECT_THAT(big_string.LastIndexOf(absl::Cord(""), 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); + EXPECT_THAT(big_string_cord.LastIndexOf(absl::Cord(""), 100), + ErrorValueIs(absl::InvalidArgumentError( + ".lastIndexOf(, ): is greater than " + "or equal to .size()"))); +} + +TEST_F(StringValueTest, Trim) { + using ::cel::test::StringValueIs; + StringValue unpadded = StringValue("no padding"); + StringValue front_padded = StringValue(" \t\r\nno padding"); + StringValue back_padded = StringValue("no padding \t\r\n"); + StringValue both_padded = StringValue(" \t\r\nno padding \t\r\n"); + StringValue whitespace = StringValue(" \t\r\n"); + StringValue empty = StringValue(""); + + EXPECT_THAT(unpadded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(front_padded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(back_padded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(both_padded.Trim(), StringValueIs("no padding")); + EXPECT_THAT(whitespace.Trim(), StringValueIs("")); + EXPECT_THAT(empty.Trim(), StringValueIs("")); + + StringValue unpadded_cord = StringValue(absl::Cord("no padding")); + StringValue front_padded_cord = StringValue(absl::Cord(" \t\r\nno padding")); + StringValue back_padded_cord = StringValue(absl::Cord("no padding \t\r\n")); + StringValue both_padded_cord = + StringValue(absl::Cord(" \t\r\nno padding \t\r\n")); + StringValue whitespace_cord = StringValue(absl::Cord(" \t\r\n")); + StringValue empty_cord = StringValue(absl::Cord("")); + + EXPECT_THAT(unpadded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(front_padded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(back_padded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(both_padded_cord.Trim(), StringValueIs("no padding")); + EXPECT_THAT(whitespace_cord.Trim(), StringValueIs("")); + EXPECT_THAT(empty_cord.Trim(), StringValueIs("")); +} + +TEST_F(StringValueTest, CharAt) { + using ::cel::test::ErrorValueIs; + using ::cel::test::StringValueIs; + StringValue big_string = + StringValue("This string is large enough to not be stored inline!"); + StringValue big_string_cord = StringValue( + absl::Cord("This string is large enough to not be stored inline!")); + StringValue small_string = StringValue("abc"); + StringValue small_string_cord = StringValue(absl::Cord("abc")); + StringValue unicode_string = StringValue("aμc"); + StringValue unicode_string_cord = StringValue(absl::Cord("aμc")); + + EXPECT_THAT(big_string.CharAt(0), StringValueIs("T")); + EXPECT_THAT(big_string_cord.CharAt(0), StringValueIs("T")); + EXPECT_THAT(small_string.CharAt(1), StringValueIs("b")); + EXPECT_THAT(small_string_cord.CharAt(1), StringValueIs("b")); + EXPECT_THAT(unicode_string.CharAt(1), StringValueIs("μ")); + EXPECT_THAT(unicode_string_cord.CharAt(1), StringValueIs("μ")); + + EXPECT_THAT( + big_string.CharAt(100), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is greater than .size()"))); + EXPECT_THAT( + big_string_cord.CharAt(100), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is greater than .size()"))); + EXPECT_THAT(big_string.CharAt(-1), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is less than 0"))); + EXPECT_THAT(big_string_cord.CharAt(-1), + ErrorValueIs(absl::InvalidArgumentError( + ".charAt(): is less than 0"))); +} + } // namespace } // namespace cel diff --git a/conformance/BUILD b/conformance/BUILD index 74390abc0..41e8e08fa 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -191,17 +191,6 @@ _TESTS_TO_SKIP_MODERN = [ "enums/strong_proto3", # Not yet implemented. - "string_ext/char_at", - "string_ext/index_of", - "string_ext/last_index_of", - "string_ext/ascii_casing/upperascii", - "string_ext/ascii_casing/upperascii_unicode", - "string_ext/ascii_casing/upperascii_unicode_with_space", - "string_ext/replace", - "string_ext/substring", - "string_ext/trim", - "string_ext/quote", - "string_ext/value_errors", "string_ext/type_errors", ] @@ -243,17 +232,6 @@ _TESTS_TO_SKIP_LEGACY = [ "optionals/optionals", # Not yet implemented. - "string_ext/char_at", - "string_ext/index_of", - "string_ext/last_index_of", - "string_ext/ascii_casing/upperascii", - "string_ext/ascii_casing/upperascii_unicode", - "string_ext/ascii_casing/upperascii_unicode_with_space", - "string_ext/replace", - "string_ext/substring", - "string_ext/trim", - "string_ext/quote", - "string_ext/value_errors", "string_ext/type_errors", # TODO(uncreated-issue/81): Fix null assignment to a field diff --git a/conformance/run.bzl b/conformance/run.bzl index 8205fa987..b984ef3a1 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -86,8 +86,11 @@ def gen_conformance_tests(name, data, modern = False, checked = False, dashboard dashboard: enable dashboard mode """ skip_check = not checked + tests = [] for optimize in (True, False): for recursive in (True, False): + test_name = _conformance_test_name(name, optimize, recursive) + tests.append(test_name) _conformance_test( name, data, @@ -99,3 +102,8 @@ def gen_conformance_tests(name, data, modern = False, checked = False, dashboard tags = tags, dashboard = dashboard, ) + native.test_suite( + name = name, + tests = tests, + tags = tags, + ) diff --git a/extensions/BUILD b/extensions/BUILD index 52d25a888..5d62a9f85 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -511,16 +511,13 @@ cc_library( "//eval/public:cel_function_registry", "//eval/public:cel_options", "//internal:status_macros", - "//internal:utf8", "//runtime:function_adapter", "//runtime:function_registry", "//runtime:runtime_options", - "//runtime/internal:errors", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", @@ -536,6 +533,7 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:decl", + "//common:type", "//common:value", "//compiler:compiler_factory", "//compiler:standard_library", @@ -550,6 +548,7 @@ cc_test( "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", "//testutil:baseline_tests", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings:cord", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", diff --git a/extensions/strings.cc b/extensions/strings.cc index 3f9c73a33..5db94d7ab 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -14,18 +14,14 @@ #include "extensions/strings.h" -#include #include -#include #include -#include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/ascii.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "checker/internal/builtins_arena.h" @@ -37,10 +33,8 @@ #include "eval/public/cel_options.h" #include "extensions/formatting.h" #include "internal/status_macros.h" -#include "internal/utf8.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" -#include "runtime/internal/errors.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -67,35 +61,7 @@ absl::StatusOr Join2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - std::string result; - CEL_ASSIGN_OR_RETURN(auto iterator, value.NewIterator()); - Value element; - if (iterator->HasNext()) { - CEL_RETURN_IF_ERROR( - iterator->Next(descriptor_pool, message_factory, arena, &element)); - if (auto string_element = element.AsString(); string_element) { - string_element->NativeValue(AppendToStringVisitor{result}); - } else { - return ErrorValue{ - runtime_internal::CreateNoMatchingOverloadError("join")}; - } - } - std::string separator_scratch; - absl::string_view separator_view = separator.NativeString(separator_scratch); - while (iterator->HasNext()) { - result.append(separator_view); - CEL_RETURN_IF_ERROR( - iterator->Next(descriptor_pool, message_factory, arena, &element)); - if (auto string_element = element.AsString(); string_element) { - string_element->NativeValue(AppendToStringVisitor{result}); - } else { - return ErrorValue{ - runtime_internal::CreateNoMatchingOverloadError("join")}; - } - } - result.shrink_to_fit(); - // We assume the original string was well-formed. - return StringValue(arena, std::move(result)); + return separator.Join(value, descriptor_pool, message_factory, arena); } absl::StatusOr Join1( @@ -103,117 +69,15 @@ absl::StatusOr Join1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - return Join2(value, StringValue{}, descriptor_pool, message_factory, arena); + return StringValue().Join(value, descriptor_pool, message_factory, arena); } -struct SplitWithEmptyDelimiter { - google::protobuf::Arena* absl_nonnull arena; - int64_t& limit; - ListValueBuilder& builder; - - absl::StatusOr operator()(absl::string_view string) const { - char32_t rune; - size_t count; - std::string buffer; - buffer.reserve(4); - while (!string.empty() && limit > 1) { - std::tie(rune, count) = internal::Utf8Decode(string); - buffer.clear(); - internal::Utf8Encode(buffer, rune); - CEL_RETURN_IF_ERROR( - builder.Add(StringValue(arena, absl::string_view(buffer)))); - --limit; - string.remove_prefix(count); - } - if (!string.empty()) { - CEL_RETURN_IF_ERROR(builder.Add(StringValue(arena, string))); - } - return std::move(builder).Build(); - } - - absl::StatusOr operator()(const absl::Cord& string) const { - auto begin = string.char_begin(); - auto end = string.char_end(); - char32_t rune; - size_t count; - std::string buffer; - while (begin != end && limit > 1) { - std::tie(rune, count) = internal::Utf8Decode(begin); - buffer.clear(); - internal::Utf8Encode(buffer, rune); - CEL_RETURN_IF_ERROR( - builder.Add(StringValue(arena, absl::string_view(buffer)))); - --limit; - absl::Cord::Advance(&begin, count); - } - if (begin != end) { - buffer.clear(); - while (begin != end) { - auto chunk = absl::Cord::ChunkRemaining(begin); - buffer.append(chunk); - absl::Cord::Advance(&begin, chunk.size()); - } - buffer.shrink_to_fit(); - CEL_RETURN_IF_ERROR(builder.Add(StringValue(arena, std::move(buffer)))); - } - return std::move(builder).Build(); - } -}; - absl::StatusOr Split3( const StringValue& string, const StringValue& delimiter, int64_t limit, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - if (limit == 0) { - // Per spec, when limit is 0 return an empty list. - return ListValue{}; - } - if (limit < 0) { - // Per spec, when limit is negative treat is as unlimited. - limit = std::numeric_limits::max(); - } - auto builder = NewListValueBuilder(arena); - if (string.IsEmpty()) { - // If string is empty, it doesn't matter what the delimiter is or the limit. - // We just return a list with a single empty string. - builder->Reserve(1); - CEL_RETURN_IF_ERROR(builder->Add(StringValue{})); - return std::move(*builder).Build(); - } - if (delimiter.IsEmpty()) { - // If the delimiter is empty, we split between every code point. - return string.NativeValue(SplitWithEmptyDelimiter{arena, limit, *builder}); - } - // At this point we know the string is not empty and the delimiter is not - // empty. - std::string delimiter_scratch; - absl::string_view delimiter_view = delimiter.NativeString(delimiter_scratch); - std::string content_scratch; - absl::string_view content_view = string.NativeString(content_scratch); - while (limit > 1 && !content_view.empty()) { - auto pos = content_view.find(delimiter_view); - if (pos == absl::string_view::npos) { - break; - } - // We assume the original string was well-formed. - CEL_RETURN_IF_ERROR( - builder->Add(StringValue(arena, content_view.substr(0, pos)))); - --limit; - content_view.remove_prefix(pos + delimiter_view.size()); - if (content_view.empty()) { - // We found the delimiter at the end of the string. Add an empty string - // to the end of the list. - CEL_RETURN_IF_ERROR(builder->Add(StringValue{})); - return std::move(*builder).Build(); - } - } - // We have one left in the limit or do not have any more matches. Add - // whatever is left as the remaining entry. - // - // We assume the original string was well-formed. - CEL_RETURN_IF_ERROR(builder->Add(StringValue(arena, content_view))); - return std::move(*builder).Build(); + return string.Split(delimiter, limit, arena); } absl::StatusOr Split2( @@ -221,27 +85,7 @@ absl::StatusOr Split2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - return Split3(string, delimiter, -1, descriptor_pool, message_factory, arena); -} - -absl::StatusOr LowerAscii(const StringValue& string, - const google::protobuf::DescriptorPool* absl_nonnull, - google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull arena) { - std::string content = string.NativeString(); - absl::AsciiStrToLower(&content); - // We assume the original string was well-formed. - return StringValue(arena, std::move(content)); -} - -absl::StatusOr UpperAscii(const StringValue& string, - const google::protobuf::DescriptorPool* absl_nonnull, - google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull arena) { - std::string content = string.NativeString(); - absl::AsciiStrToUpper(&content); - // We assume the original string was well-formed. - return StringValue(arena, std::move(content)); + return string.Split(delimiter, arena); } absl::StatusOr Replace2(const StringValue& string, @@ -250,38 +94,7 @@ absl::StatusOr Replace2(const StringValue& string, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { - if (limit == 0) { - // When the replacement limit is 0, the result is the original string. - return string; - } - if (limit < 0) { - // Per spec, when limit is negative treat is as unlimited. - limit = std::numeric_limits::max(); - } - - std::string result; - std::string old_sub_scratch; - absl::string_view old_sub_view = old_sub.NativeString(old_sub_scratch); - std::string new_sub_scratch; - absl::string_view new_sub_view = new_sub.NativeString(new_sub_scratch); - std::string content_scratch; - absl::string_view content_view = string.NativeString(content_scratch); - while (limit > 0 && !content_view.empty()) { - auto pos = content_view.find(old_sub_view); - if (pos == absl::string_view::npos) { - break; - } - result.append(content_view.substr(0, pos)); - result.append(new_sub_view); - --limit; - content_view.remove_prefix(pos + old_sub_view.size()); - } - // Add the remainder of the string. - if (!content_view.empty()) { - result.append(content_view); - } - - return StringValue(arena, std::move(result)); + return string.Replace(old_sub, new_sub, limit, arena); } absl::StatusOr Replace1( @@ -290,8 +103,67 @@ absl::StatusOr Replace1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - return Replace2(string, old_sub, new_sub, -1, descriptor_pool, - message_factory, arena); + return string.Replace(old_sub, new_sub, -1, arena); +} + +Value CharAt(const StringValue& string, int64_t pos) { + return string.CharAt(pos); +} + +int64_t IndexOf2(const StringValue& haystack, const StringValue& needle) { + return haystack.IndexOf(needle); +} + +Value IndexOf3(const StringValue& haystack, const StringValue& needle, + int64_t pos) { + return haystack.IndexOf(needle, pos); +} + +int64_t LastIndexOf2(const StringValue& haystack, const StringValue& needle) { + return haystack.LastIndexOf(needle); +} + +Value LastIndexOf3(const StringValue& haystack, const StringValue& needle, + int64_t pos) { + return haystack.LastIndexOf(needle, pos); +} + +Value Substring2(const StringValue& string, int64_t start) { + return string.Substring(start); +} + +Value Substring3(const StringValue& string, int64_t start, int64_t end) { + return string.Substring(start, end); +} + +StringValue Trim(const StringValue& string) { return string.Trim(); } + +StringValue LowerAscii(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.LowerAscii(arena); +} + +StringValue UpperAscii(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.UpperAscii(arena); +} + +StringValue Quote(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.Quote(arena); +} + +StringValue Reverse(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.Reverse(arena); } const Type& ListStringType() { @@ -391,6 +263,11 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder) { MakeMemberOverloadDecl("string_reverse", StringType(), StringType()))); + CEL_ASSIGN_OR_RETURN( + auto trim_decl, + MakeFunctionDecl("trim", MakeMemberOverloadDecl( + "string_trim", StringType(), StringType()))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(char_at_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(index_of_decl))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(last_index_of_decl))); @@ -401,6 +278,7 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder) { // MergeFunction is used to combine with the reverse function // defined in cel.lib.ext.lists extension. CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl))); + CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(trim_decl))); return absl::OkStatus(); } @@ -453,6 +331,49 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("indexOf", + &IndexOf2, + registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter::RegisterMemberOverload("indexOf", + &IndexOf3, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", + &LastIndexOf2, + registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", + &LastIndexOf3, + registry))); + CEL_RETURN_IF_ERROR( + (BinaryFunctionAdapter::RegisterMemberOverload("substring", + &Substring2, + registry))); + CEL_RETURN_IF_ERROR( + (TernaryFunctionAdapter::RegisterMemberOverload("substring", + &Substring3, + registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterMemberOverload( + "trim", &Trim, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "strings.quote", &Quote, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterMemberOverload( + "reverse", &Reverse, registry))); return absl::OkStatus(); } diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index e2eb5e71f..714c2a3a7 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -15,15 +15,18 @@ #include "extensions/strings.h" #include +#include #include #include "cel/expr/syntax.pb.h" +#include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" #include "checker/standard_library.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/decl.h" +#include "common/type.h" #include "common/value.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" @@ -280,6 +283,104 @@ TEST(StringsCheckerLibrary, SmokeTest) { )~bool^equals)"); } +using StringsExtFunctionsTest = testing::TestWithParam; + +TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) { + const std::string& expr = GetParam(); + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); + + auto result = compiler->Compile(expr, ""); + + ASSERT_THAT(result, IsOk()); + ASSERT_TRUE(result->IsValid()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT( + RegisterStringsFunctions(runtime_builder.function_registry(), opts), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, + runtime->CreateProgram(*result->ReleaseAst())); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); +} + +INSTANTIATE_TEST_SUITE_P( + StringsExtMacrosParamsTest, StringsExtFunctionsTest, + testing::Values( + // Tests for charAt() + "'tacocat'.charAt(3) == 'o'", "'tacocat'.charAt(7) == ''", + "'©αT'.charAt(0) == '©' && '©αT'.charAt(1) == 'α' && '©αT'.charAt(2) " + "== 'T'", + + // Tests for indexOf() + "'tacocat'.indexOf('') == 0", "'tacocat'.indexOf('ac') == 1", + "'tacocat'.indexOf('none') == -1", "'tacocat'.indexOf('', 3) == 3", + "'tacocat'.indexOf('a', 3) == 5", "'tacocat'.indexOf('at', 3) == 5", + "'ta©o©αT'.indexOf('©') == 2", "'ta©o©αT'.indexOf('©', 3) == 4", + "'ta©o©αT'.indexOf('©αT', 3) == 4", "'ta©o©αT'.indexOf('©α', 5) == -1", + "'ijk'.indexOf('k') == 2", "'hello wello'.indexOf('hello wello') == 0", + "'hello wello'.indexOf('ello', 6) == 7", + "'hello wello'.indexOf('elbo room!!') == -1", + "'hello wello'.indexOf('elbo room!!!') == -1", + "''.lastIndexOf('@@') == -1", "'tacocat'.lastIndexOf('') == 7", + "'tacocat'.lastIndexOf('at') == 5", + "'tacocat'.lastIndexOf('none') == -1", + "'tacocat'.lastIndexOf('', 3) == 3", + "'tacocat'.lastIndexOf('a', 3) == 1", "'ta©o©αT'.lastIndexOf('©') == 4", + "'ta©o©αT'.lastIndexOf('©', 3) == 2", + "'ta©o©αT'.lastIndexOf('©α', 4) == 4", + "'hello wello'.lastIndexOf('ello', 6) == 1", + "'hello wello'.lastIndexOf('low') == -1", + "'hello wello'.lastIndexOf('elbo room!!') == -1", + "'hello wello'.lastIndexOf('elbo room!!!') == -1", + "'hello wello'.lastIndexOf('hello wello') == 0", + "'bananananana'.lastIndexOf('nana', 7) == 6", + + // Tests for substring() + "'tacocat'.substring(4) == 'cat'", "'tacocat'.substring(7) == ''", + "'tacocat'.substring(0, 4) == 'taco'", + "'tacocat'.substring(4, 4) == ''", + "'ta©o©αT'.substring(2, 6) == '©o©α'", + "'ta©o©αT'.substring(7, 7) == ''", + + // Tests for strings.quote() + R"(strings.quote("first\nsecond") == "\"first\\nsecond\"")", + R"(strings.quote("bell\a") == "\"bell\\a\"")", + R"(strings.quote("\bbackspace") == "\"\\bbackspace\"")", + R"(strings.quote("\fform feed") == "\"\\fform feed\"")", + R"(strings.quote("carriage \r return") == "\"carriage \\r return\"")", + R"(strings.quote("vertical \v tab") == "\"vertical \\v tab\"")", + R"(strings.quote("verbatim") == "\"verbatim\"")", + R"(strings.quote("ends with \\") == "\"ends with \\\\\"")", + R"(strings.quote("\\ starts with") == "\"\\\\ starts with\"")", + + // Tests for trim() + R"(' \f\n\r\t\vtext '.trim() == 'text')", + R"('\u0085\u00a0\u1680text'.trim() == 'text')", + R"('text\u2000\u2001\u2002\u2003\u2004\u2004\u2006\u2007\u2008\u2009'.trim() == 'text')", + R"('\u200atext\u2028\u2029\u202F\u205F\u3000'.trim() == 'text')", + R"(' hello world '.trim() == 'hello world')")); + // Basic test for the included declarations. // Additional coverage for behavior in the spec tests. class StringsCheckerLibraryTest : public ::testing::TestWithParam { @@ -314,7 +415,42 @@ INSTANTIATE_TEST_SUITE_P( "'tacocat'.substring(1) == 'acocat'", "'tacocat'.substring(1, 3) == 'aco'", "'aBc'.upperAscii() == 'ABC'", "'abc %d'.format([2]) == 'abc 2'", - "strings.quote('abc') == \"'abc 2'\"", "'abc'.reverse() == 'cba'")); + "strings.quote('abc') == \"'abc 2'\"", "'abc'.reverse() == 'cba'", + "'ta©o©αT'.substring(7, 7) == ''")); + +class StringsRuntimeErrorTest : public ::testing::TestWithParam {}; + +TEST_P(StringsRuntimeErrorTest, RuntimeTests) { + const std::string& expr_string = GetParam(); + const auto options = RuntimeOptions{}; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + EXPECT_THAT(RegisterStringsFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(expr_string, "", ParserOptions{})); + + EXPECT_THAT( + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + absl_testing::StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr("No overloads provided"))); +} + +INSTANTIATE_TEST_SUITE_P( + TypeErrors, StringsRuntimeErrorTest, + Values( + // string_ext.type_errors/indexof_ternary_invalid_arguments + "'42'.indexOf('4', 0, 1) == 0", + // string_ext.type_errors/replace_quaternary_invalid_argument + "'42'.replace('2', '1', 1, false) == '41'", + // string_ext.type_errors/split_ternary_invalid_argument + "'42'.split('2', 1, 1) == ['4']", + // string_ext.type_errors/substring_ternary_invalid_argument + "'hello'.substring(1, 2, 3) == ''")); } // namespace } // namespace cel::extensions