Skip to content

Commit b062998

Browse files
dinsepitrou
andauthored
GH-46860: [C++] Making HalfFloatBuilder accept Float16 as well as uint16_t (#46981)
### Rationale for this change #46860 Adding convenience methods for appending and retrieving Float16 to HalfFloatBuilder. ### What changes are included in this PR? HalfFloatBuilder has functions overloaded to accept Float16, tests, and documentation. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * GitHub Issue: #46860 Lead-authored-by: Eric Dinse <293818+dinse@users.noreply.github.com> Co-authored-by: Antoine Pitrou <antoine@python.org> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 6c9e30b commit b062998

4 files changed

Lines changed: 181 additions & 2 deletions

File tree

cpp/src/arrow/array/array_test.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "arrow/array/builder_binary.h"
3939
#include "arrow/array/builder_decimal.h"
4040
#include "arrow/array/builder_dict.h"
41+
#include "arrow/array/builder_primitive.h"
4142
#include "arrow/array/builder_run_end.h"
4243
#include "arrow/array/builder_time.h"
4344
#include "arrow/array/data.h"
@@ -60,6 +61,7 @@
6061
#include "arrow/util/bitmap_builders.h"
6162
#include "arrow/util/checked_cast.h"
6263
#include "arrow/util/decimal.h"
64+
#include "arrow/util/float16.h"
6365
#include "arrow/util/key_value_metadata.h"
6466
#include "arrow/util/macros.h"
6567
#include "arrow/util/range.h"
@@ -72,6 +74,7 @@ namespace arrow {
7274

7375
using internal::checked_cast;
7476
using internal::checked_pointer_cast;
77+
using util::Float16;
7578

7679
class TestArray : public ::testing::Test {
7780
public:
@@ -4099,4 +4102,73 @@ TYPED_TEST(TestPrimitiveArray, IndexOperator) {
40994102
}
41004103
}
41014104

4105+
class TestHalfFloatBuilder : public ::testing::Test {
4106+
public:
4107+
void VerifyValue(const HalfFloatBuilder& builder, int64_t index, float expected) {
4108+
ASSERT_EQ(builder.GetValue(index), Float16(expected).bits());
4109+
ASSERT_EQ(builder.GetValue<Float16>(index), Float16(expected));
4110+
ASSERT_EQ(builder.GetValue<uint16_t>(index), Float16(expected).bits());
4111+
ASSERT_EQ(builder[index], Float16(expected).bits());
4112+
}
4113+
};
4114+
4115+
TEST_F(TestHalfFloatBuilder, TestAppend) {
4116+
HalfFloatBuilder builder;
4117+
ASSERT_OK(builder.Append(Float16(0.0f)));
4118+
ASSERT_OK(builder.Append(Float16(1.0f).bits()));
4119+
ASSERT_OK(builder.AppendNull());
4120+
ASSERT_OK(builder.Reserve(3));
4121+
builder.UnsafeAppend(Float16(3.0f));
4122+
builder.UnsafeAppend(Float16(4.0f).bits());
4123+
builder.UnsafeAppend(uint16_t{15872}); // 1.5f
4124+
4125+
VerifyValue(builder, 0, 0.0f);
4126+
VerifyValue(builder, 1, 1.0f);
4127+
VerifyValue(builder, 3, 3.0f);
4128+
VerifyValue(builder, 4, 4.0f);
4129+
VerifyValue(builder, 5, 1.5f);
4130+
}
4131+
4132+
TEST_F(TestHalfFloatBuilder, TestBulkAppend) {
4133+
HalfFloatBuilder builder;
4134+
4135+
ASSERT_OK(builder.AppendValues(5, Float16(1.5)));
4136+
uint16_t val = Float16(2.0f).bits();
4137+
ASSERT_OK(builder.AppendValues({val, val, val, val}, {0, 1, 0, 1}));
4138+
ASSERT_EQ(builder.length(), 9);
4139+
for (int i = 0; i < 5; i++) {
4140+
VerifyValue(builder, i, 1.5f);
4141+
}
4142+
4143+
{
4144+
ASSERT_OK_AND_ASSIGN(auto array, builder.Finish());
4145+
ASSERT_OK(array->ValidateFull());
4146+
ASSERT_EQ(array->null_count(), 2);
4147+
ASSERT_EQ(array->length(), 9);
4148+
auto comp = ArrayFromJSON(float16(), "[1.5,1.5,1.5,1.5,1.5,null,2,null,2]");
4149+
AssertArraysEqual(*array, *comp);
4150+
}
4151+
4152+
std::vector<Float16> vals = {Float16(1.0f), Float16(2.0f), Float16(3.0f)};
4153+
std::vector<bool> is_valid = {true, false, true};
4154+
std::vector<uint8_t> valid_bytes = {1, 0, 1};
4155+
std::vector<uint8_t> bitmap = {0b00000101};
4156+
ASSERT_OK(builder.AppendValues(vals));
4157+
ASSERT_OK(builder.AppendValues(vals, is_valid));
4158+
ASSERT_OK(builder.AppendValues(vals.data(), vals.size(), is_valid));
4159+
ASSERT_OK(builder.AppendValues(vals.data(), vals.size()));
4160+
ASSERT_OK(builder.AppendValues(vals.data(), vals.size(), valid_bytes.data()));
4161+
ASSERT_OK(builder.AppendValues(vals.data(), vals.size(), bitmap.data(), 0));
4162+
4163+
{
4164+
ASSERT_OK_AND_ASSIGN(auto array, builder.Finish());
4165+
ASSERT_OK(array->ValidateFull());
4166+
ASSERT_EQ(array->null_count(), 4);
4167+
ASSERT_EQ(array->length(), 18);
4168+
auto comp =
4169+
ArrayFromJSON(float16(), "[1,2,3,1,null,3,1,null,3,1,2,3,1,null,3,1,null,3]");
4170+
AssertArraysEqual(*array, *comp);
4171+
}
4172+
}
4173+
41024174
} // namespace arrow

cpp/src/arrow/array/builder_primitive.h

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "arrow/result.h"
2727
#include "arrow/type.h"
2828
#include "arrow/type_traits.h"
29+
#include "arrow/util/float16.h"
2930

3031
namespace arrow {
3132

@@ -364,7 +365,6 @@ using Int16Builder = NumericBuilder<Int16Type>;
364365
using Int32Builder = NumericBuilder<Int32Type>;
365366
using Int64Builder = NumericBuilder<Int64Type>;
366367

367-
using HalfFloatBuilder = NumericBuilder<HalfFloatType>;
368368
using FloatBuilder = NumericBuilder<FloatType>;
369369
using DoubleBuilder = NumericBuilder<DoubleType>;
370370

@@ -384,6 +384,107 @@ using DurationBuilder = NumericBuilder<DurationType>;
384384

385385
/// @}
386386

387+
/// \addtogroup numeric-builders
388+
///
389+
/// @{
390+
391+
class ARROW_EXPORT HalfFloatBuilder : public NumericBuilder<HalfFloatType> {
392+
public:
393+
using BaseClass = NumericBuilder<HalfFloatType>;
394+
using Float16 = arrow::util::Float16;
395+
396+
using BaseClass::Append;
397+
using BaseClass::AppendValues;
398+
using BaseClass::BaseClass;
399+
using BaseClass::GetValue;
400+
using BaseClass::UnsafeAppend;
401+
402+
/// Scalar append a arrow::util::Float16
403+
Status Append(const Float16 val) { return Append(val.bits()); }
404+
405+
/// Scalar append a arrow::util::Float16, without checking for capacity
406+
void UnsafeAppend(const Float16 val) { UnsafeAppend(val.bits()); }
407+
408+
/// \brief Append a sequence of elements in one shot
409+
/// \param[in] values a contiguous array of arrow::util::Float16
410+
/// \param[in] length the number of values to append
411+
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
412+
/// indicates a valid (non-null) value
413+
/// \return Status
414+
Status AppendValues(const Float16* values, int64_t length,
415+
const uint8_t* valid_bytes = NULLPTR) {
416+
return BaseClass::AppendValues(reinterpret_cast<const uint16_t*>(values), length,
417+
valid_bytes);
418+
}
419+
420+
/// \brief Append a sequence of elements in one shot
421+
/// \param[in] values a contiguous array of arrow::util::Float16
422+
/// \param[in] length the number of values to append
423+
/// \param[in] bitmap a validity bitmap to copy (may be null)
424+
/// \param[in] bitmap_offset an offset into the validity bitmap
425+
/// \return Status
426+
Status AppendValues(const Float16* values, int64_t length, const uint8_t* bitmap,
427+
int64_t bitmap_offset) {
428+
return BaseClass::AppendValues(reinterpret_cast<const uint16_t*>(values), length,
429+
bitmap, bitmap_offset);
430+
}
431+
432+
/// \brief Append a sequence of elements in one shot
433+
/// \param[in] values a contiguous array of arrow::util::Float16
434+
/// \param[in] length the number of values to append
435+
/// \param[in] is_valid a std::vector<bool> indicating valid (1) or null
436+
/// (0). Equal in length to values
437+
/// \return Status
438+
Status AppendValues(const Float16* values, int64_t length,
439+
const std::vector<bool>& is_valid) {
440+
return BaseClass::AppendValues(reinterpret_cast<const uint16_t*>(values), length,
441+
is_valid);
442+
}
443+
444+
/// \brief Append a sequence of elements in one shot
445+
/// \param[in] values a std::vector<arrow::util::Float16>
446+
/// \param[in] is_valid a std::vector<bool> indicating valid (1) or null
447+
/// (0). Equal in length to values
448+
/// \return Status
449+
Status AppendValues(const std::vector<Float16>& values,
450+
const std::vector<bool>& is_valid) {
451+
return AppendValues(values.data(), static_cast<int64_t>(values.size()), is_valid);
452+
}
453+
454+
/// \brief Append a sequence of elements in one shot
455+
/// \param[in] values a std::vector<arrow::util::Float16>
456+
/// \return Status
457+
Status AppendValues(const std::vector<Float16>& values) {
458+
return AppendValues(values.data(), static_cast<int64_t>(values.size()));
459+
}
460+
461+
/// \brief Append one value many times in one shot
462+
/// \param[in] length the number of values to append
463+
/// \param[in] value a arrow::util::Float16
464+
Status AppendValues(int64_t length, Float16 value) {
465+
RETURN_NOT_OK(Reserve(length));
466+
data_builder_.UnsafeAppend(length, value.bits());
467+
ArrayBuilder::UnsafeSetNotNull(length);
468+
return Status::OK();
469+
}
470+
471+
/// \brief Get the value at a certain index
472+
/// \param[in] index the zero-based index
473+
/// @tparam T arrow::util::Float16 or value_type (uint16_t)
474+
template <typename T = BaseClass::value_type>
475+
T GetValue(int64_t index) const {
476+
static_assert(std::is_same_v<T, BaseClass::value_type> ||
477+
std::is_same_v<T, arrow::util::Float16>);
478+
if constexpr (std::is_same_v<T, BaseClass::value_type>) {
479+
return BaseClass::GetValue(index);
480+
} else {
481+
return Float16::FromBits(BaseClass::GetValue(index));
482+
}
483+
}
484+
};
485+
486+
/// @}
487+
387488
class ARROW_EXPORT BooleanBuilder
388489
: public ArrayBuilder,
389490
public internal::ArrayBuilderExtraOps<BooleanBuilder, bool> {

cpp/src/arrow/type_fwd.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,17 @@ _NUMERIC_TYPE_DECL(UInt8)
242242
_NUMERIC_TYPE_DECL(UInt16)
243243
_NUMERIC_TYPE_DECL(UInt32)
244244
_NUMERIC_TYPE_DECL(UInt64)
245-
_NUMERIC_TYPE_DECL(HalfFloat)
246245
_NUMERIC_TYPE_DECL(Float)
247246
_NUMERIC_TYPE_DECL(Double)
248247

249248
#undef _NUMERIC_TYPE_DECL
250249

250+
class HalfFloatType;
251+
using HalfFloatArray = NumericArray<HalfFloatType>;
252+
class HalfFloatBuilder;
253+
struct HalfFloatScalar;
254+
using HalfFloatTensor = NumericTensor<HalfFloatType>;
255+
251256
enum class DateUnit : char { DAY = 0, MILLI = 1 };
252257

253258
class DateType;

cpp/src/arrow/util/float16.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class ARROW_EXPORT Float16 {
178178
}
179179
};
180180

181+
static_assert(std::is_standard_layout_v<Float16>);
181182
static_assert(std::is_trivial_v<Float16>);
182183
static_assert(sizeof(Float16) == sizeof(uint16_t));
183184

0 commit comments

Comments
 (0)