Skip to content

Commit 2b400b1

Browse files
authored
Small cleanup for rowset collection. (dmlc#10401)
1 parent e5f1720 commit 2b400b1

File tree

1 file changed

+55
-60
lines changed

1 file changed

+55
-60
lines changed

src/common/row_set.h

+55 -60
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,23 @@
1-
/*!
2-
* Copyright 2017-2022 by Contributors
1+
/**
2+
* Copyright 2017-2024, XGBoost Contributors
33
* \file row_set.h
44
* \brief Quick Utility to compute subset of rows
55
* \author Philip Cho, Tianqi Chen
66
*/
77
#ifndef XGBOOST_COMMON_ROW_SET_H_
88
#define XGBOOST_COMMON_ROW_SET_H_
99

10-
#include <xgboost/data.h>
11-
#include <algorithm>
12-
#include <vector>
13-
#include <utility>
14-
#include <memory>
10+
#include <cstddef> // for size_t
11+
#include <iterator> // for distance
12+
#include <vector> // for vector
1513

16-
namespace xgboost {
17-
namespace common {
18-
/*! \brief collection of rowset */
14+
#include "xgboost/base.h" // for bst_node_t
15+
#include "xgboost/logging.h" // for CHECK
16+
17+
namespace xgboost::common {
18+
/**
19+
* @brief Collection of rows for each tree node.
20+
*/
1921
class RowSetCollection {
2022
public:
2123
RowSetCollection() = default;
@@ -24,110 +26,103 @@ class RowSetCollection {
2426
RowSetCollection& operator=(RowSetCollection const&) = delete;
2527
RowSetCollection& operator=(RowSetCollection&&) = default;
2628

27-
/*! \brief data structure to store an instance set, a subset of
28-
* rows (instances) associated with a particular node in a decision
29-
* tree. */
29+
/**
30+
* @brief data structure to store an instance set, a subset of rows (instances)
31+
* associated with a particular node in a decision tree.
32+
*/
3033
struct Elem {
31-
const size_t* begin{nullptr};
32-
const size_t* end{nullptr};
34+
std::size_t const* begin{nullptr};
35+
std::size_t const* end{nullptr};
3336
bst_node_t node_id{-1};
34-
// id of node associated with this instance set; -1 means uninitialized
35-
Elem()
36-
= default;
37-
Elem(const size_t* begin,
38-
const size_t* end,
39-
bst_node_t node_id = -1)
37+
// id of node associated with this instance set; -1 means uninitialized
38+
Elem() = default;
39+
Elem(std::size_t const* begin, std::size_t const* end, bst_node_t node_id = -1)
4040
: begin(begin), end(end), node_id(node_id) {}
4141

42-
inline size_t Size() const {
43-
return end - begin;
44-
}
42+
std::size_t Size() const { return end - begin; }
4543
};
4644

47-
std::vector<Elem>::const_iterator begin() const { // NOLINT
48-
return elem_of_each_node_.begin();
45+
[[nodiscard]] std::vector<Elem>::const_iterator begin() const { // NOLINT
46+
return elem_of_each_node_.cbegin();
4947
}
50-
51-
std::vector<Elem>::const_iterator end() const { // NOLINT
52-
return elem_of_each_node_.end();
48+
[[nodiscard]] std::vector<Elem>::const_iterator end() const { // NOLINT
49+
return elem_of_each_node_.cend();
5350
}
5451

55-
size_t Size() const { return std::distance(begin(), end()); }
52+
[[nodiscard]] std::size_t Size() const { return std::distance(begin(), end()); }
5653

57-
/*! \brief return corresponding element set given the node_id */
58-
inline const Elem& operator[](unsigned node_id) const {
59-
const Elem& e = elem_of_each_node_[node_id];
54+
/** @brief return corresponding element set given the node_id */
55+
[[nodiscard]] Elem const& operator[](bst_node_t node_id) const {
56+
Elem const& e = elem_of_each_node_[node_id];
6057
return e;
6158
}
62-
63-
/*! \brief return corresponding element set given the node_id */
64-
inline Elem& operator[](unsigned node_id) {
59+
/** @brief return corresponding element set given the node_id */
60+
[[nodiscard]] Elem& operator[](bst_node_t node_id) {
6561
Elem& e = elem_of_each_node_[node_id];
6662
return e;
6763
}
6864

6965
// clear up things
70-
inline void Clear() {
66+
void Clear() {
7167
elem_of_each_node_.clear();
7268
}
7369
// initialize node id 0->everything
74-
inline void Init() {
75-
CHECK_EQ(elem_of_each_node_.size(), 0U);
70+
void Init() {
71+
CHECK(elem_of_each_node_.empty());
7672

7773
if (row_indices_.empty()) { // edge case: empty instance set
78-
constexpr size_t* kBegin = nullptr;
79-
constexpr size_t* kEnd = nullptr;
74+
constexpr std::size_t* kBegin = nullptr;
75+
constexpr std::size_t* kEnd = nullptr;
8076
static_assert(kEnd - kBegin == 0);
8177
elem_of_each_node_.emplace_back(kBegin, kEnd, 0);
8278
return;
8379
}
8480

85-
const size_t* begin = dmlc::BeginPtr(row_indices_);
86-
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
81+
const std::size_t* begin = dmlc::BeginPtr(row_indices_);
82+
const std::size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
8783
elem_of_each_node_.emplace_back(begin, end, 0);
8884
}
8985

90-
std::vector<size_t>* Data() { return &row_indices_; }
91-
std::vector<size_t> const* Data() const { return &row_indices_; }
86+
[[nodiscard]] std::vector<std::size_t>* Data() { return &row_indices_; }
87+
[[nodiscard]] std::vector<std::size_t> const* Data() const { return &row_indices_; }
9288

9389
// split rowset into two
94-
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
95-
size_t n_left, size_t n_right) {
90+
void AddSplit(bst_node_t node_id, bst_node_t left_node_id, bst_node_t right_node_id,
91+
bst_idx_t n_left, bst_idx_t n_right) {
9692
const Elem e = elem_of_each_node_[node_id];
9793

98-
size_t* all_begin{nullptr};
99-
size_t* begin{nullptr};
94+
std::size_t* all_begin{nullptr};
95+
std::size_t* begin{nullptr};
10096
if (e.begin == nullptr) {
10197
CHECK_EQ(n_left, 0);
10298
CHECK_EQ(n_right, 0);
10399
} else {
104-
all_begin = dmlc::BeginPtr(row_indices_);
100+
all_begin = row_indices_.data();
105101
begin = all_begin + (e.begin - all_begin);
106102
}
107103

108104
CHECK_EQ(n_left + n_right, e.Size());
109105
CHECK_LE(begin + n_left, e.end);
110106
CHECK_EQ(begin + n_left + n_right, e.end);
111107

112-
if (left_node_id >= elem_of_each_node_.size()) {
113-
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
108+
if (left_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
109+
elem_of_each_node_.resize(left_node_id + 1, Elem{nullptr, nullptr, -1});
114110
}
115-
if (right_node_id >= elem_of_each_node_.size()) {
116-
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
111+
if (right_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
112+
elem_of_each_node_.resize(right_node_id + 1, Elem{nullptr, nullptr, -1});
117113
}
118114

119-
elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
120-
elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
121-
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
115+
elem_of_each_node_[left_node_id] = Elem{begin, begin + n_left, left_node_id};
116+
elem_of_each_node_[right_node_id] = Elem{begin + n_left, e.end, right_node_id};
117+
elem_of_each_node_[node_id] = Elem{nullptr, nullptr, -1};
122118
}
123119

124120
private:
125121
// stores the row indexes in the set
126-
std::vector<size_t> row_indices_;
122+
std::vector<std::size_t> row_indices_;
127123
// vector: node_id -> elements
128124
std::vector<Elem> elem_of_each_node_;
129125
};
130-
} // namespace common
131-
} // namespace xgboost
126+
} // namespace xgboost::common
132127

133128
#endif // XGBOOST_COMMON_ROW_SET_H_

0 commit comments

Comments
 (0)