Skip to content

Commit 4894506

Browse files
committed
Address comment: fix hash aggregate as well
1 parent 7317ced commit 4894506

3 files changed

Lines changed: 116 additions & 40 deletions

File tree

cpp/src/arrow/acero/hash_aggregate_test.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,74 @@ TEST_P(GroupBy, MinMaxScalar) {
20002000
}
20012001
}
20022002

2003+
// Checks hash_min / hash_max / hash_min_max on float32 and float64 columns
// containing NaN: a group whose only values are NaN yields NaN, while a group
// that also holds a non-NaN value (here -Inf or Inf) yields that value.
TEST_P(GroupBy, MinMaxWithNaN) {
  auto in_schema = schema({
      field("argument1", float32()),
      field("argument2", float64()),
      field("key", int64()),
  });

  for (bool use_threads : {true, false}) {
    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

    // Two JSON batches: the first is NaN for every key, the second keeps
    // key 1 at NaN and gives keys 2 and 3 an infinite value.
    auto table = TableFromJSON(in_schema, {R"([
    [NaN, NaN, 1],
    [NaN, NaN, 2],
    [NaN, NaN, 3]
])",
                                           R"([
    [NaN, NaN, 1],
    [-Inf, -Inf, 2],
    [Inf, Inf, 3]
])"});

    auto f32_values = table->GetColumnByName("argument1");
    auto f64_values = table->GetColumnByName("argument2");
    auto group_keys = table->GetColumnByName("key");

    // Each value column is fed to the three aggregations in turn.
    ASSERT_OK_AND_ASSIGN(
        Datum aggregated_and_grouped,
        GroupByTest({f32_values, f32_values, f32_values,  //
                     f64_values, f64_values, f64_values},
                    {group_keys},
                    {
                        {"hash_min", nullptr},
                        {"hash_max", nullptr},
                        {"hash_min_max", nullptr},
                        {"hash_min", nullptr},
                        {"hash_max", nullptr},
                        {"hash_min_max", nullptr},
                    },
                    use_threads));
    ValidateOutput(aggregated_and_grouped);
    SortBy({"key_0"}, &aggregated_and_grouped);

    auto expected_type = struct_({
        field("key_0", int64()),
        field("hash_min", float32()),
        field("hash_max", float32()),
        field("hash_min_max", struct_({
                                  field("min", float32()),
                                  field("max", float32()),
                              })),
        field("hash_min", float64()),
        field("hash_max", float64()),
        field("hash_min_max", struct_({
                                  field("min", float64()),
                                  field("max", float64()),
                              })),
    });
    auto expected = ArrayFromJSON(expected_type, R"([
    [1, NaN, NaN, {"min": NaN, "max": NaN}, NaN, NaN, {"min": NaN, "max": NaN}],
    [2, -Inf, -Inf, {"min": -Inf, "max": -Inf}, -Inf, -Inf, {"min": -Inf, "max": -Inf}],
    [3, Inf, Inf, {"min": Inf, "max": Inf}, Inf, Inf, {"min": Inf, "max": Inf}]
])");

    AssertDatumsEqual(expected, aggregated_and_grouped, /*verbose=*/true);
  }
}
2070+
20032071
TEST_P(GroupBy, AnyAndAll) {
20042072
for (bool use_threads : {true, false}) {
20052073
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

cpp/src/arrow/compute/kernels/hash_aggregate.cc

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -276,46 +276,34 @@ struct AntiExtrema {
276276
static constexpr CType anti_max() { return std::numeric_limits<CType>::min(); }
277277
};
278278

279-
template <>
280-
struct AntiExtrema<bool> {
281-
static constexpr bool anti_min() { return true; }
282-
static constexpr bool anti_max() { return false; }
279+
template <CBooleanConcept CType>
280+
struct AntiExtrema<CType> {
281+
static constexpr CType anti_min() { return true; }
282+
static constexpr CType anti_max() { return false; }
283283
};
284284

285-
template <>
286-
struct AntiExtrema<float> {
287-
static constexpr float anti_min() { return std::numeric_limits<float>::infinity(); }
288-
static constexpr float anti_max() { return -std::numeric_limits<float>::infinity(); }
285+
template <CFloatingPointConcept CType>
286+
struct AntiExtrema<CType> {
287+
static constexpr CType anti_min() { return std::numeric_limits<CType>::quiet_NaN(); }
288+
static constexpr CType anti_max() { return std::numeric_limits<CType>::quiet_NaN(); }
289289
};
290290

291-
template <>
292-
struct AntiExtrema<double> {
293-
static constexpr double anti_min() { return std::numeric_limits<double>::infinity(); }
294-
static constexpr double anti_max() { return -std::numeric_limits<double>::infinity(); }
291+
template <CDecimalConcept CType>
292+
struct AntiExtrema<CType> {
293+
static constexpr CType anti_min() { return CType::GetMaxSentinel(); }
294+
static constexpr CType anti_max() { return CType::GetMinSentinel(); }
295295
};
296296

297-
template <>
298-
struct AntiExtrema<Decimal32> {
299-
static constexpr Decimal32 anti_min() { return BasicDecimal32::GetMaxSentinel(); }
300-
static constexpr Decimal32 anti_max() { return BasicDecimal32::GetMinSentinel(); }
301-
};
302-
303-
template <>
304-
struct AntiExtrema<Decimal64> {
305-
static constexpr Decimal64 anti_min() { return BasicDecimal64::GetMaxSentinel(); }
306-
static constexpr Decimal64 anti_max() { return BasicDecimal64::GetMinSentinel(); }
307-
};
308-
309-
template <>
310-
struct AntiExtrema<Decimal128> {
311-
static constexpr Decimal128 anti_min() { return BasicDecimal128::GetMaxSentinel(); }
312-
static constexpr Decimal128 anti_max() { return BasicDecimal128::GetMinSentinel(); }
297+
// Generic binary min/max used to fold values into a running extremum.
// Only `operator<` is required of CType; on ties the first argument is
// returned, matching std::min / std::max.
template <typename CType>
struct MinMaxOp {
  static constexpr CType min(CType a, CType b) { return (b < a) ? b : a; }
  static constexpr CType max(CType a, CType b) { return (a < b) ? b : a; }
};
314302

315-
template <>
316-
struct AntiExtrema<Decimal256> {
317-
static constexpr Decimal256 anti_min() { return BasicDecimal256::GetMaxSentinel(); }
318-
static constexpr Decimal256 anti_max() { return BasicDecimal256::GetMinSentinel(); }
303+
template <CFloatingPointConcept CType>
304+
struct MinMaxOp<CType> {
305+
static constexpr CType min(CType a, CType b) { return std::fmin(a, b); }
306+
static constexpr CType max(CType a, CType b) { return std::fmax(a, b); }
319307
};
320308

321309
template <typename Type, typename Enable = void>
@@ -352,8 +340,8 @@ struct GroupedMinMaxImpl final : public GroupedAggregator {
352340
VisitGroupedValues<Type>(
353341
batch,
354342
[&](uint32_t g, CType val) {
355-
GetSet::Set(raw_mins, g, std::min(GetSet::Get(raw_mins, g), val));
356-
GetSet::Set(raw_maxes, g, std::max(GetSet::Get(raw_maxes, g), val));
343+
GetSet::Set(raw_mins, g, MinMaxOp<CType>::min(GetSet::Get(raw_mins, g), val));
344+
GetSet::Set(raw_maxes, g, MinMaxOp<CType>::max(GetSet::Get(raw_maxes, g), val));
357345
bit_util::SetBit(has_values_.mutable_data(), g);
358346
},
359347
[&](uint32_t g) { bit_util::SetBit(has_nulls_.mutable_data(), g); });
@@ -373,12 +361,12 @@ struct GroupedMinMaxImpl final : public GroupedAggregator {
373361
auto g = group_id_mapping.GetValues<uint32_t>(1);
374362
for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < group_id_mapping.length;
375363
++other_g, ++g) {
376-
GetSet::Set(
377-
raw_mins, *g,
378-
std::min(GetSet::Get(raw_mins, *g), GetSet::Get(other_raw_mins, other_g)));
379-
GetSet::Set(
380-
raw_maxes, *g,
381-
std::max(GetSet::Get(raw_maxes, *g), GetSet::Get(other_raw_maxes, other_g)));
364+
GetSet::Set(raw_mins, *g,
365+
MinMaxOp<CType>::min(GetSet::Get(raw_mins, *g),
366+
GetSet::Get(other_raw_mins, other_g)));
367+
GetSet::Set(raw_maxes, *g,
368+
MinMaxOp<CType>::max(GetSet::Get(raw_maxes, *g),
369+
GetSet::Get(other_raw_maxes, other_g)));
382370

383371
if (bit_util::GetBit(other->has_values_.data(), other_g)) {
384372
bit_util::SetBit(has_values_.mutable_data(), *g);

cpp/src/arrow/type_traits.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#pragma once
1919

20+
#include <concepts>
2021
#include <memory>
2122
#include <string>
2223
#include <type_traits>
@@ -1837,4 +1838,23 @@ constexpr bool is_union(const DataType& type) { return is_union(type.id()); }
18371838

18381839
/// @}
18391840

1841+
/// \addtogroup c-type-concepts
1842+
/// @{
1843+
1844+
// XXX: To be completed with more concepts as needed.
1845+
1846+
template <typename T>
1847+
concept CBooleanConcept = std::is_same_v<T, bool>;
1848+
1849+
// XXX: Ideally we want to have std::floating_point<Float16> = true.
1850+
template <typename T>
1851+
concept CFloatingPointConcept =
1852+
std::floating_point<T> || std::is_same_v<T, util::Float16>;
1853+
1854+
template <typename T>
1855+
concept CDecimalConcept = std::is_same_v<T, Decimal32> || std::is_same_v<T, Decimal64> ||
1856+
std::is_same_v<T, Decimal128> || std::is_same_v<T, Decimal256>;
1857+
1858+
/// @}
1859+
18401860
} // namespace arrow

0 commit comments

Comments
 (0)