From 5fb448ccd772eeb13332e5d7a4585e98f74e4936 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Mon, 9 Jun 2025 23:22:01 +0800 Subject: [PATCH] [llvm][ADT] Add wrappers to `std::includes` Add `llvm::includes` that accepts a range instead of start/end iterator. --- llvm/include/llvm/ADT/STLExtras.h | 22 ++++++++++++++++++++++ llvm/unittests/ADT/STLExtrasTest.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index 897dc76a420b2..eea06cfb99ba2 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1940,6 +1940,28 @@ template bool is_sorted(R &&Range) { return std::is_sorted(adl_begin(Range), adl_end(Range)); } +/// Provide wrappers to std::includes which take ranges instead of having to +/// pass begin/end explicitly. +/// This function checks if the sorted range \p R2 is a subsequence of the +/// sorted range \p R1. The ranges must be sorted in non-descending order. +template bool includes(R1 &&Range1, R2 &&Range2) { + assert(is_sorted(Range1) && "Range1 must be sorted in non-descending order"); + assert(is_sorted(Range2) && "Range2 must be sorted in non-descending order"); + return std::includes(adl_begin(Range1), adl_end(Range1), adl_begin(Range2), + adl_end(Range2)); +} + +/// This function checks if the sorted range \p R2 is a subsequence of the +/// sorted range \p R1. The ranges must be sorted with respect to a comparator +/// \p C. +template +bool includes(R1 &&Range1, R2 &&Range2, Compare &&C) { + assert(is_sorted(Range1, C) && "Range1 must be sorted with respect to C"); + assert(is_sorted(Range2, C) && "Range2 must be sorted with respect to C"); + return std::includes(adl_begin(Range1), adl_end(Range1), adl_begin(Range2), + adl_end(Range2), std::forward(C)); +} + /// Wrapper function around std::count to count the number of times an element /// \p Element occurs in the given range \p Range. template auto count(R &&Range, const E &Element) { diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp index 0101be47a6869..286cfa745fd14 100644 --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -1567,6 +1567,30 @@ TEST(STLExtrasTest, Mismatch) { } } +TEST(STLExtrasTest, Includes) { + { + std::vector V1 = {1, 2}; + std::vector V2; + EXPECT_TRUE(includes(V1, V2)); + EXPECT_FALSE(includes(V2, V1)); + V2.push_back(1); + EXPECT_TRUE(includes(V1, V2)); + V2.push_back(3); + EXPECT_FALSE(includes(V1, V2)); + } + + { + std::vector V1 = {3, 2, 1}; + std::vector V2; + EXPECT_TRUE(includes(V1, V2, std::greater<>())); + EXPECT_FALSE(includes(V2, V1, std::greater<>())); + V2.push_back(3); + EXPECT_TRUE(includes(V1, V2, std::greater<>())); + V2.push_back(0); + EXPECT_FALSE(includes(V1, V2, std::greater<>())); + } +} + struct Foo; struct Bar {};