Skip to content

Commit f9283d0

Browse files
authored
apacheGH-45572: [C++][Compute] Add rank_normal function (apache#45573)
### Rationale for this change Computing ranks as values of the "probit" function (https://en.wikipedia.org/wiki/Probit), rather than quantiles between 0 and 1, can be useful for machine learning and other tasks. ### What changes are included in this PR? Add a "rank_normal" function that computes array ranks as points on the normal distribution. It is similar to calling the "rank_quantile" function and then applying the normal percent-point function ("probit"). ### Are these changes tested? Yes, by dedicated unit tests. ### Are there any user-facing changes? No, except a new compute function. * GitHub Issue: apache#45572 Authored-by: Antoine Pitrou <antoine@python.org> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 01e3f1e commit f9283d0

10 files changed

Lines changed: 461 additions & 38 deletions

File tree

LICENSE.txt

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2257,5 +2257,36 @@ SOFTWARE.
22572257
java/vector/src/main/java/org/apache/arrow/vector/util/IntObjectHashMap.java
22582258
java/vector/src/main/java/org/apache/arrow/vector/util/IntObjectMap.java
22592259

2260-
These file are derived from code from Netty, which is made available under the
2260+
These files are derived from code from Netty, which is made available under the
22612261
Apache License 2.0.
2262+
2263+
--------------------------------------------------------------------------------
2264+
cpp/src/arrow/util/math_internal.cc (some portions)
2265+
2266+
Some portions of this file are derived from
2267+
2268+
https://github.com/ankane/dist-rust/
2269+
2270+
which is made available under the MIT license
2271+
2272+
The MIT License (MIT)
2273+
2274+
Copyright (c) 2021-2023 Contributors
2275+
2276+
Permission is hereby granted, free of charge, to any person obtaining a copy
2277+
of this software and associated documentation files (the "Software"), to deal
2278+
in the Software without restriction, including without limitation the rights
2279+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
2280+
copies of the Software, and to permit persons to whom the Software is
2281+
furnished to do so, subject to the following conditions:
2282+
2283+
The above copyright notice and this permission notice shall be included in
2284+
all copies or substantial portions of the Software.
2285+
2286+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
2287+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
2288+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
2289+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
2290+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2291+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2292+
THE SOFTWARE.

cpp/src/arrow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ set(ARROW_UTIL_SRCS
529529
util/logger.cc
530530
util/logging.cc
531531
util/key_value_metadata.cc
532+
util/math_internal.cc
532533
util/memory.cc
533534
util/mutex.cc
534535
util/ree_util.cc

cpp/src/arrow/compute/kernels/vector_rank.cc

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "arrow/compute/function.h"
2222
#include "arrow/compute/kernels/vector_sort_internal.h"
2323
#include "arrow/compute/registry.h"
24+
#include "arrow/util/math_internal.h"
2425

2526
namespace arrow::compute::internal {
2627

@@ -62,16 +63,6 @@ void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_sel
6263
}
6364
}
6465

65-
const RankOptions* GetDefaultRankOptions() {
66-
static const auto kDefaultRankOptions = RankOptions::Defaults();
67-
return &kDefaultRankOptions;
68-
}
69-
70-
const RankQuantileOptions* GetDefaultQuantileRankOptions() {
71-
static const auto kDefaultQuantileRankOptions = RankQuantileOptions::Defaults();
72-
return &kDefaultQuantileRankOptions;
73-
}
74-
7566
template <typename ArrowType>
7667
Result<NullPartitionResult> DoSortAndMarkDuplicate(
7768
ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const Array& input,
@@ -164,8 +155,9 @@ class SortAndMarkDuplicate : public TypeVisitor {
164155
NullPartitionResult sorted_{};
165156
};
166157

167-
// A helper class that emits rankings for the "rank_quantile" function
168-
struct QuantileRanker {
158+
// A CRTP-based helper class for "rank_normal" and "rank_quantile"
159+
template <typename Derived>
160+
struct BaseQuantileRanker {
169161
Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) {
170162
const int64_t length = sorted.overall_end() - sorted.overall_begin();
171163
ARROW_ASSIGN_OR_RAISE(auto rankings,
@@ -187,10 +179,11 @@ struct QuantileRanker {
187179
}
188180
// The run length, i.e. the frequency of the current value
189181
int64_t freq = run_end - it;
190-
double quantile = (cum_freq + 0.5 * freq) / static_cast<double>(length);
182+
const double quantile = (cum_freq + 0.5 * freq) / static_cast<double>(length);
183+
const double value = Derived::TransformValue(quantile);
191184
// Output quantile rank values
192185
for (; it < run_end; ++it) {
193-
out_begin[original_index(*it)] = quantile;
186+
out_begin[original_index(*it)] = value;
194187
}
195188
cum_freq += freq;
196189
}
@@ -199,6 +192,18 @@ struct QuantileRanker {
199192
}
200193
};
201194

195+
// A derived class that emits rankings for the "rank_quantile" function
196+
struct QuantileRanker : public BaseQuantileRanker<QuantileRanker> {
197+
static double TransformValue(double quantile) { return quantile; }
198+
};
199+
200+
// A derived class that emits rankings for the "rank_normal" function
201+
struct NormalRanker : public BaseQuantileRanker<NormalRanker> {
202+
static double TransformValue(double quantile) {
203+
return ::arrow::internal::NormalPPF(quantile);
204+
}
205+
};
206+
202207
// A helper class that emits rankings for the "rank" function
203208
struct OrdinalRanker {
204209
explicit OrdinalRanker(RankOptions::Tiebreaker tiebreaker) : tiebreaker_(tiebreaker) {}
@@ -294,6 +299,20 @@ const FunctionDoc rank_quantile_doc(
294299
"The handling of nulls and NaNs can be changed in RankQuantileOptions."),
295300
{"input"}, "RankQuantileOptions");
296301

302+
const FunctionDoc rank_normal_doc(
303+
"Compute normal (gaussian) ranks of an array",
304+
("This function computes a normal (gaussian) rank of the input array.\n"
305+
"By default, null values are considered greater than any other value and\n"
306+
"are therefore sorted at the end of the input. For floating-point types,\n"
307+
"NaNs are considered greater than any other non-null value, but smaller\n"
308+
"than null values.\n"
309+
"The results are finite real values. They are obtained as if first\n"
310+
"calling the \"rank_quantile\" function and then applying the normal\n"
311+
"percent-point function (PPF) to the resulting quantile values.\n"
312+
"\n"
313+
"The handling of nulls and NaNs can be changed in RankQuantileOptions."),
314+
{"input"}, "RankQuantileOptions");
315+
297316
template <typename Derived>
298317
class RankMetaFunctionBase : public MetaFunction {
299318
public:
@@ -361,11 +380,14 @@ class RankMetaFunction : public RankMetaFunctionBase<RankMetaFunction> {
361380
}
362381

363382
RankMetaFunction()
364-
: RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {}
383+
: RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, &kDefaultOptions) {}
384+
385+
static inline const auto kDefaultOptions = RankOptions::Defaults();
365386
};
366387

367388
class RankQuantileMetaFunction : public RankMetaFunctionBase<RankQuantileMetaFunction> {
368389
public:
390+
using Base = RankMetaFunctionBase<RankQuantileMetaFunction>;
369391
using FunctionOptionsType = RankQuantileOptions;
370392
using RankerType = QuantileRanker;
371393

@@ -375,14 +397,34 @@ class RankQuantileMetaFunction : public RankMetaFunctionBase<RankQuantileMetaFun
375397

376398
RankQuantileMetaFunction()
377399
: RankMetaFunctionBase("rank_quantile", Arity::Unary(), rank_quantile_doc,
378-
GetDefaultQuantileRankOptions()) {}
400+
&kDefaultOptions) {}
401+
402+
static inline const auto kDefaultOptions = RankQuantileOptions::Defaults();
403+
};
404+
405+
class RankNormalMetaFunction : public RankMetaFunctionBase<RankNormalMetaFunction> {
406+
public:
407+
// Alias for this class's own CRTP base; it must name RankNormalMetaFunction
// to match the base-specifier in the class head.
using Base = RankMetaFunctionBase<RankNormalMetaFunction>;
408+
using FunctionOptionsType = RankQuantileOptions;
409+
using RankerType = NormalRanker;
410+
411+
static bool NeedsDuplicates(const RankQuantileOptions&) { return true; }
412+
413+
static RankerType GetRanker(const RankQuantileOptions& options) { return RankerType(); }
414+
415+
RankNormalMetaFunction()
416+
: RankMetaFunctionBase("rank_normal", Arity::Unary(), rank_normal_doc,
417+
&kDefaultOptions) {}
418+
419+
static inline const auto kDefaultOptions = RankQuantileOptions::Defaults();
379420
};
380421

381422
} // namespace
382423

383424
void RegisterVectorRank(FunctionRegistry* registry) {
384425
DCHECK_OK(registry->AddFunction(std::make_shared<RankMetaFunction>()));
385426
DCHECK_OK(registry->AddFunction(std::make_shared<RankQuantileMetaFunction>()));
427+
DCHECK_OK(registry->AddFunction(std::make_shared<RankNormalMetaFunction>()));
386428
}
387429

388430
} // namespace arrow::compute::internal

cpp/src/arrow/compute/kernels/vector_sort_test.cc

Lines changed: 91 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,7 +2205,7 @@ TEST_F(TestNestedSortIndices, SortRecordBatch) { TestSort(GetRecordBatch()); }
22052205
TEST_F(TestNestedSortIndices, SortTable) { TestSort(GetTable()); }
22062206

22072207
// ----------------------------------------------------------------------
2208-
// Tests for Rank and Quantile Rank
2208+
// Tests for Rank, Quantile Rank and Normal Rank
22092209

22102210
class BaseTestRank : public ::testing::Test {
22112211
protected:
@@ -2471,43 +2471,84 @@ TEST_F(TestRank, EmptyChunks) {
24712471

24722472
class TestRankQuantile : public BaseTestRank {
24732473
public:
2474-
void AssertRankQuantile(const DatumVector& datums, SortOrder order,
2475-
NullPlacement null_placement,
2476-
const std::shared_ptr<Array>& expected) {
2474+
void AssertRankQuantileGeneric(const std::string& function_name,
2475+
const DatumVector& datums, SortOrder order,
2476+
NullPlacement null_placement,
2477+
const std::shared_ptr<Array>& expected) {
2478+
ARROW_SCOPED_TRACE("function = ", function_name);
24772479
const std::vector<SortKey> sort_keys{SortKey("foo", order)};
24782480
RankQuantileOptions options(sort_keys, null_placement);
24792481
ARROW_SCOPED_TRACE("options = ", options.ToString());
24802482
for (const auto& datum : datums) {
2481-
ASSERT_OK_AND_ASSIGN(auto actual, CallFunction("rank_quantile", {datum}, &options));
2483+
ASSERT_OK_AND_ASSIGN(auto actual, CallFunction(function_name, {datum}, &options));
24822484
ValidateOutput(actual);
2483-
AssertDatumsEqual(expected, actual, /*verbose=*/true);
2485+
if (function_name == "rank_normal") {
2486+
// Normal PPF results can only be approximate
2487+
auto equal_options = EqualOptions().atol(1e-8);
2488+
AssertDatumsApproxEqual(expected, actual, /*verbose=*/true, equal_options);
2489+
} else {
2490+
AssertDatumsEqual(expected, actual, /*verbose=*/true);
2491+
}
24842492
}
24852493
}
24862494

2487-
void AssertRankQuantile(const DatumVector& datums, SortOrder order,
2488-
NullPlacement null_placement, const std::string& expected) {
2489-
AssertRankQuantile(datums, order, null_placement, ArrayFromJSON(float64(), expected));
2495+
void AssertRankQuantileGeneric(const std::string& function_name, const Datum& datum,
2496+
SortOrder order, NullPlacement null_placement,
2497+
const std::shared_ptr<Array>& expected) {
2498+
AssertRankQuantileGeneric(function_name, DatumVector{datum}, order, null_placement,
2499+
expected);
24902500
}
24912501

2492-
void AssertRankQuantile(SortOrder order, NullPlacement null_placement,
2493-
const std::shared_ptr<Array>& expected) {
2494-
AssertRankQuantile(datums_, order, null_placement, expected);
2502+
void AssertRankQuantileGeneric(const std::string& function_name,
2503+
const DatumVector& datums, SortOrder order,
2504+
NullPlacement null_placement,
2505+
const std::string& expected) {
2506+
AssertRankQuantileGeneric(function_name, datums, order, null_placement,
2507+
ArrayFromJSON(float64(), expected));
24952508
}
24962509

2497-
void AssertRankQuantile(SortOrder order, NullPlacement null_placement,
2498-
const std::string& expected) {
2499-
AssertRankQuantile(datums_, order, null_placement,
2500-
ArrayFromJSON(float64(), expected));
2510+
void AssertRankQuantileGeneric(const std::string& function_name, const Datum& datum,
2511+
SortOrder order, NullPlacement null_placement,
2512+
const std::string& expected) {
2513+
AssertRankQuantileGeneric(function_name, DatumVector{datum}, order, null_placement,
2514+
ArrayFromJSON(float64(), expected));
2515+
}
2516+
2517+
void AssertRankQuantileGeneric(const std::string& function_name, SortOrder order,
2518+
NullPlacement null_placement,
2519+
const std::shared_ptr<Array>& expected) {
2520+
AssertRankQuantileGeneric(function_name, datums_, order, null_placement, expected);
2521+
}
2522+
2523+
void AssertRankQuantileGeneric(const std::string& function_name, SortOrder order,
2524+
NullPlacement null_placement,
2525+
const std::string& expected) {
2526+
AssertRankQuantileGeneric(function_name, datums_, order, null_placement,
2527+
ArrayFromJSON(float64(), expected));
2528+
}
2529+
2530+
template <typename... Args>
2531+
void AssertRankQuantile(Args&&... args) {
2532+
AssertRankQuantileGeneric("rank_quantile", std::forward<Args>(args)...);
2533+
}
2534+
2535+
template <typename... Args>
2536+
void AssertRankNormal(Args&&... args) {
2537+
AssertRankQuantileGeneric("rank_normal", std::forward<Args>(args)...);
25012538
}
25022539

25032540
void AssertRankQuantileEmpty(std::shared_ptr<DataType> type) {
25042541
for (auto null_placement : AllNullPlacements()) {
25052542
for (auto order : AllOrders()) {
2506-
AssertRankQuantile({ArrayFromJSON(type, "[]")}, order, null_placement, "[]");
2507-
AssertRankQuantile({ArrayFromJSON(type, "[null]")}, order, null_placement,
2508-
"[0.5]");
2509-
AssertRankQuantile({ArrayFromJSON(type, "[null, null, null]")}, order,
2543+
AssertRankQuantile(ArrayFromJSON(type, "[]"), order, null_placement, "[]");
2544+
AssertRankQuantile(ArrayFromJSON(type, "[null]"), order, null_placement, "[0.5]");
2545+
AssertRankQuantile(ArrayFromJSON(type, "[null, null, null]"), order,
25102546
null_placement, "[0.5, 0.5, 0.5]");
2547+
2548+
AssertRankNormal(ArrayFromJSON(type, "[]"), order, null_placement, "[]");
2549+
AssertRankNormal(ArrayFromJSON(type, "[null]"), order, null_placement, "[0.0]");
2550+
AssertRankNormal(ArrayFromJSON(type, "[null, null, null]"), order, null_placement,
2551+
"[0.0, 0.0, 0.0]");
25112552
}
25122553
}
25132554
}
@@ -2519,6 +2560,12 @@ class TestRankQuantile : public BaseTestRank {
25192560
"[0.3, 0.8, 0.3, 0.8, 0.3]");
25202561
AssertRankQuantile(SortOrder::Descending, null_placement,
25212562
"[0.7, 0.2, 0.7, 0.2, 0.7]");
2563+
AssertRankNormal(SortOrder::Ascending, null_placement,
2564+
"[-0.5244005127080409, 0.8416212335729143, -0.5244005127080409, "
2565+
"0.8416212335729143, -0.5244005127080409]");
2566+
AssertRankNormal(SortOrder::Descending, null_placement,
2567+
"[0.5244005127080407, -0.8416212335729142, 0.5244005127080407, "
2568+
"-0.8416212335729142, 0.5244005127080407]");
25222569
}
25232570
}
25242571

@@ -2532,6 +2579,19 @@ class TestRankQuantile : public BaseTestRank {
25322579
"[0.3, 0.9, 0.3, 0.7, 0.3]");
25332580
AssertRankQuantile(SortOrder::Descending, NullPlacement::AtEnd,
25342581
"[0.7, 0.3, 0.7, 0.1, 0.7]");
2582+
2583+
AssertRankNormal(SortOrder::Ascending, NullPlacement::AtStart,
2584+
"[-0.5244005127080409, 0.5244005127080407, -0.5244005127080409, "
2585+
"1.2815515655446004, -0.5244005127080409]");
2586+
AssertRankNormal(SortOrder::Ascending, NullPlacement::AtEnd,
2587+
"[0.5244005127080407, -1.2815515655446004, 0.5244005127080407, "
2588+
"-0.5244005127080409, 0.5244005127080407]");
2589+
AssertRankNormal(SortOrder::Descending, NullPlacement::AtStart,
2590+
"[-0.5244005127080409, 1.2815515655446004, -0.5244005127080409, "
2591+
"0.5244005127080407, -0.5244005127080409]");
2592+
AssertRankNormal(SortOrder::Descending, NullPlacement::AtEnd,
2593+
"[0.5244005127080407, -0.5244005127080409, 0.5244005127080407, "
2594+
"-1.2815515655446004, 0.5244005127080407]");
25352595
}
25362596

25372597
void AssertRankQuantileNumeric(std::shared_ptr<DataType> type) {
@@ -2545,6 +2605,17 @@ class TestRankQuantile : public BaseTestRank {
25452605
"[0.95, 0.8, 0.8, 0.6, 0.6, 0.35, 0.35, 0.35, 0.15, 0.05]");
25462606
AssertRankQuantile(SortOrder::Descending, null_placement,
25472607
"[0.05, 0.2, 0.2, 0.4, 0.4, 0.65, 0.65, 0.65, 0.85, 0.95]");
2608+
2609+
AssertRankNormal(SortOrder::Ascending, null_placement,
2610+
"[1.6448536269514722, 0.8416212335729143, 0.8416212335729143, "
2611+
"0.2533471031357997, 0.2533471031357997, -0.38532046640756773, "
2612+
"-0.38532046640756773, -0.38532046640756773, -1.0364333894937898, "
2613+
"-1.6448536269514729]");
2614+
AssertRankNormal(SortOrder::Descending, null_placement,
2615+
"[-1.6448536269514729, -0.8416212335729142, -0.8416212335729142, "
2616+
"-0.2533471031357997, -0.2533471031357997, 0.38532046640756773, "
2617+
"0.38532046640756773, 0.38532046640756773, 1.0364333894937898, "
2618+
"1.6448536269514722]");
25482619
}
25492620

25502621
// With nulls

cpp/src/arrow/util/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ add_arrow_test(utility-test
6666
list_util_test.cc
6767
logger_test.cc
6868
logging_test.cc
69+
math_test.cc
6970
queue_test.cc
7071
range_test.cc
7172
ree_util_test.cc

0 commit comments

Comments
 (0)