1
- /* !
2
- * Copyright 2017-2022 by Contributors
1
+ /* *
2
+ * Copyright 2017-2024, XGBoost Contributors
3
3
* \file row_set.h
4
4
* \brief Quick Utility to compute subset of rows
5
5
* \author Philip Cho, Tianqi Chen
6
6
*/
7
7
#ifndef XGBOOST_COMMON_ROW_SET_H_
8
8
#define XGBOOST_COMMON_ROW_SET_H_
9
9
10
- #include < xgboost/data.h>
11
- #include < algorithm>
12
- #include < vector>
13
- #include < utility>
14
- #include < memory>
10
+ #include < cstddef> // for size_t
11
+ #include < iterator> // for distance
12
+ #include < vector> // for vector
15
13
16
- namespace xgboost {
17
- namespace common {
18
- /* ! \brief collection of rowset */
14
+ #include " xgboost/base.h" // for bst_node_t
15
+ #include " xgboost/logging.h" // for CHECK
16
+
17
+ namespace xgboost ::common {
18
+ /* *
19
+ * @brief Collection of rows for each tree node.
20
+ */
19
21
class RowSetCollection {
20
22
public:
21
23
RowSetCollection () = default ;
@@ -24,110 +26,103 @@ class RowSetCollection {
24
26
RowSetCollection& operator =(RowSetCollection const &) = delete ;
25
27
RowSetCollection& operator =(RowSetCollection&&) = default ;
26
28
27
- /* ! \brief data structure to store an instance set, a subset of
28
- * rows (instances) associated with a particular node in a decision
29
- * tree. */
29
+ /* *
30
+ * @brief data structure to store an instance set, a subset of rows (instances)
31
+ * associated with a particular node in a decision tree.
32
+ */
30
33
struct Elem {
31
- const size_t * begin{nullptr };
32
- const size_t * end{nullptr };
34
+ std:: size_t const * begin{nullptr };
35
+ std:: size_t const * end{nullptr };
33
36
bst_node_t node_id{-1 };
34
- // id of node associated with this instance set; -1 means uninitialized
35
- Elem ()
36
- = default ;
37
- Elem (const size_t * begin,
38
- const size_t * end,
39
- bst_node_t node_id = -1 )
37
+ // id of node associated with this instance set; -1 means uninitialized
38
+ Elem () = default ;
39
+ Elem (std::size_t const * begin, std::size_t const * end, bst_node_t node_id = -1 )
40
40
: begin(begin), end(end), node_id(node_id) {}
41
41
42
- inline size_t Size () const {
43
- return end - begin;
44
- }
42
+ std::size_t Size () const { return end - begin; }
45
43
};
46
44
47
- std::vector<Elem>::const_iterator begin () const { // NOLINT
48
- return elem_of_each_node_.begin ();
45
+ [[nodiscard]] std::vector<Elem>::const_iterator begin () const { // NOLINT
46
+ return elem_of_each_node_.cbegin ();
49
47
}
50
-
51
- std::vector<Elem>::const_iterator end () const { // NOLINT
52
- return elem_of_each_node_.end ();
48
+ [[nodiscard]] std::vector<Elem>::const_iterator end () const { // NOLINT
49
+ return elem_of_each_node_.cend ();
53
50
}
54
51
55
- size_t Size () const { return std::distance (begin (), end ()); }
52
+ [[nodiscard]] std:: size_t Size () const { return std::distance (begin (), end ()); }
56
53
57
- /* ! \ brief return corresponding element set given the node_id */
58
- inline const Elem& operator [](unsigned node_id) const {
59
- const Elem& e = elem_of_each_node_[node_id];
54
+ /* * @ brief return corresponding element set given the node_id */
55
+ [[nodiscard]] Elem const & operator [](bst_node_t node_id) const {
56
+ Elem const & e = elem_of_each_node_[node_id];
60
57
return e;
61
58
}
62
-
63
- /* ! \brief return corresponding element set given the node_id */
64
- inline Elem& operator [](unsigned node_id) {
59
+ /* * @brief return corresponding element set given the node_id */
60
+ [[nodiscard]] Elem& operator [](bst_node_t node_id) {
65
61
Elem& e = elem_of_each_node_[node_id];
66
62
return e;
67
63
}
68
64
69
65
// clear up things
70
- inline void Clear () {
66
+ void Clear () {
71
67
elem_of_each_node_.clear ();
72
68
}
73
69
// initialize node id 0->everything
74
- inline void Init () {
75
- CHECK_EQ (elem_of_each_node_.size (), 0U );
70
+ void Init () {
71
+ CHECK (elem_of_each_node_.empty () );
76
72
77
73
if (row_indices_.empty ()) { // edge case: empty instance set
78
- constexpr size_t * kBegin = nullptr ;
79
- constexpr size_t * kEnd = nullptr ;
74
+ constexpr std:: size_t * kBegin = nullptr ;
75
+ constexpr std:: size_t * kEnd = nullptr ;
80
76
static_assert (kEnd - kBegin == 0 );
81
77
elem_of_each_node_.emplace_back (kBegin , kEnd , 0 );
82
78
return ;
83
79
}
84
80
85
- const size_t * begin = dmlc::BeginPtr (row_indices_);
86
- const size_t * end = dmlc::BeginPtr (row_indices_) + row_indices_.size ();
81
+ const std:: size_t * begin = dmlc::BeginPtr (row_indices_);
82
+ const std:: size_t * end = dmlc::BeginPtr (row_indices_) + row_indices_.size ();
87
83
elem_of_each_node_.emplace_back (begin, end, 0 );
88
84
}
89
85
90
- std::vector<size_t >* Data () { return &row_indices_; }
91
- std::vector<size_t > const * Data () const { return &row_indices_; }
86
+ [[nodiscard]] std::vector<std:: size_t >* Data () { return &row_indices_; }
87
+ [[nodiscard]] std::vector<std:: size_t > const * Data () const { return &row_indices_; }
92
88
93
89
// split rowset into two
94
- inline void AddSplit (unsigned node_id, unsigned left_node_id, unsigned right_node_id,
95
- size_t n_left, size_t n_right) {
90
+ void AddSplit (bst_node_t node_id, bst_node_t left_node_id, bst_node_t right_node_id,
91
+ bst_idx_t n_left, bst_idx_t n_right) {
96
92
const Elem e = elem_of_each_node_[node_id];
97
93
98
- size_t * all_begin{nullptr };
99
- size_t * begin{nullptr };
94
+ std:: size_t * all_begin{nullptr };
95
+ std:: size_t * begin{nullptr };
100
96
if (e.begin == nullptr ) {
101
97
CHECK_EQ (n_left, 0 );
102
98
CHECK_EQ (n_right, 0 );
103
99
} else {
104
- all_begin = dmlc::BeginPtr ( row_indices_);
100
+ all_begin = row_indices_. data ( );
105
101
begin = all_begin + (e.begin - all_begin);
106
102
}
107
103
108
104
CHECK_EQ (n_left + n_right, e.Size ());
109
105
CHECK_LE (begin + n_left, e.end );
110
106
CHECK_EQ (begin + n_left + n_right, e.end );
111
107
112
- if (left_node_id >= elem_of_each_node_.size ()) {
113
- elem_of_each_node_.resize (left_node_id + 1 , Elem ( nullptr , nullptr , -1 ) );
108
+ if (left_node_id >= static_cast < bst_node_t >( elem_of_each_node_.size () )) {
109
+ elem_of_each_node_.resize (left_node_id + 1 , Elem{ nullptr , nullptr , -1 } );
114
110
}
115
- if (right_node_id >= elem_of_each_node_.size ()) {
116
- elem_of_each_node_.resize (right_node_id + 1 , Elem ( nullptr , nullptr , -1 ) );
111
+ if (right_node_id >= static_cast < bst_node_t >( elem_of_each_node_.size () )) {
112
+ elem_of_each_node_.resize (right_node_id + 1 , Elem{ nullptr , nullptr , -1 } );
117
113
}
118
114
119
- elem_of_each_node_[left_node_id] = Elem ( begin, begin + n_left, left_node_id) ;
120
- elem_of_each_node_[right_node_id] = Elem ( begin + n_left, e.end , right_node_id) ;
121
- elem_of_each_node_[node_id] = Elem ( nullptr , nullptr , -1 ) ;
115
+ elem_of_each_node_[left_node_id] = Elem{ begin, begin + n_left, left_node_id} ;
116
+ elem_of_each_node_[right_node_id] = Elem{ begin + n_left, e.end , right_node_id} ;
117
+ elem_of_each_node_[node_id] = Elem{ nullptr , nullptr , -1 } ;
122
118
}
123
119
124
120
private:
125
121
// stores the row indexes in the set
126
- std::vector<size_t > row_indices_;
122
+ std::vector<std:: size_t > row_indices_;
127
123
// vector: node_id -> elements
128
124
std::vector<Elem> elem_of_each_node_;
129
125
};
130
- } // namespace common
131
- } // namespace xgboost
126
+ } // namespace xgboost::common
132
127
133
128
#endif // XGBOOST_COMMON_ROW_SET_H_
0 commit comments