/**
 * Copyright 2017-2024, XGBoost Contributors
 * \file row_set.h
 * \brief Quick Utility to compute subset of rows
 * \author Philip Cho, Tianqi Chen
 */
77#ifndef XGBOOST_COMMON_ROW_SET_H_
88#define XGBOOST_COMMON_ROW_SET_H_
99
10- #include < xgboost/data.h>
11- #include < algorithm>
12- #include < vector>
13- #include < utility>
14- #include < memory>
10+ #include < cstddef> // for size_t
11+ #include < iterator> // for distance
12+ #include < vector> // for vector
1513
16- namespace xgboost {
17- namespace common {
18- /* ! \brief collection of rowset */
14+ #include " xgboost/base.h" // for bst_node_t
15+ #include " xgboost/logging.h" // for CHECK
16+
17+ namespace xgboost ::common {
18+ /* *
19+ * @brief Collection of rows for each tree node.
20+ */
1921class RowSetCollection {
2022 public:
2123 RowSetCollection () = default ;
@@ -24,110 +26,103 @@ class RowSetCollection {
2426 RowSetCollection& operator =(RowSetCollection const &) = delete ;
2527 RowSetCollection& operator =(RowSetCollection&&) = default ;
2628
27- /* ! \brief data structure to store an instance set, a subset of
28- * rows (instances) associated with a particular node in a decision
29- * tree. */
29+ /* *
30+ * @brief data structure to store an instance set, a subset of rows (instances)
31+ * associated with a particular node in a decision tree.
32+ */
3033 struct Elem {
31- const size_t * begin{nullptr };
32- const size_t * end{nullptr };
34+ std:: size_t const * begin{nullptr };
35+ std:: size_t const * end{nullptr };
3336 bst_node_t node_id{-1 };
34- // id of node associated with this instance set; -1 means uninitialized
35- Elem ()
36- = default ;
37- Elem (const size_t * begin,
38- const size_t * end,
39- bst_node_t node_id = -1 )
37+ // id of node associated with this instance set; -1 means uninitialized
38+ Elem () = default ;
39+ Elem (std::size_t const * begin, std::size_t const * end, bst_node_t node_id = -1 )
4040 : begin(begin), end(end), node_id(node_id) {}
4141
42- inline size_t Size () const {
43- return end - begin;
44- }
42+ std::size_t Size () const { return end - begin; }
4543 };
4644
47- std::vector<Elem>::const_iterator begin () const { // NOLINT
48- return elem_of_each_node_.begin ();
45+ [[nodiscard]] std::vector<Elem>::const_iterator begin () const { // NOLINT
46+ return elem_of_each_node_.cbegin ();
4947 }
50-
51- std::vector<Elem>::const_iterator end () const { // NOLINT
52- return elem_of_each_node_.end ();
48+ [[nodiscard]] std::vector<Elem>::const_iterator end () const { // NOLINT
49+ return elem_of_each_node_.cend ();
5350 }
5451
55- size_t Size () const { return std::distance (begin (), end ()); }
52+ [[nodiscard]] std:: size_t Size () const { return std::distance (begin (), end ()); }
5653
57- /* ! \ brief return corresponding element set given the node_id */
58- inline const Elem& operator [](unsigned node_id) const {
59- const Elem& e = elem_of_each_node_[node_id];
54+ /* * @ brief return corresponding element set given the node_id */
55+ [[nodiscard]] Elem const & operator [](bst_node_t node_id) const {
56+ Elem const & e = elem_of_each_node_[node_id];
6057 return e;
6158 }
62-
63- /* ! \brief return corresponding element set given the node_id */
64- inline Elem& operator [](unsigned node_id) {
59+ /* * @brief return corresponding element set given the node_id */
60+ [[nodiscard]] Elem& operator [](bst_node_t node_id) {
6561 Elem& e = elem_of_each_node_[node_id];
6662 return e;
6763 }
6864
6965 // clear up things
70- inline void Clear () {
66+ void Clear () {
7167 elem_of_each_node_.clear ();
7268 }
7369 // initialize node id 0->everything
74- inline void Init () {
75- CHECK_EQ (elem_of_each_node_.size (), 0U );
70+ void Init () {
71+ CHECK (elem_of_each_node_.empty () );
7672
7773 if (row_indices_.empty ()) { // edge case: empty instance set
78- constexpr size_t * kBegin = nullptr ;
79- constexpr size_t * kEnd = nullptr ;
74+ constexpr std:: size_t * kBegin = nullptr ;
75+ constexpr std:: size_t * kEnd = nullptr ;
8076 static_assert (kEnd - kBegin == 0 );
8177 elem_of_each_node_.emplace_back (kBegin , kEnd , 0 );
8278 return ;
8379 }
8480
85- const size_t * begin = dmlc::BeginPtr (row_indices_);
86- const size_t * end = dmlc::BeginPtr (row_indices_) + row_indices_.size ();
81+ const std:: size_t * begin = dmlc::BeginPtr (row_indices_);
82+ const std:: size_t * end = dmlc::BeginPtr (row_indices_) + row_indices_.size ();
8783 elem_of_each_node_.emplace_back (begin, end, 0 );
8884 }
8985
90- std::vector<size_t >* Data () { return &row_indices_; }
91- std::vector<size_t > const * Data () const { return &row_indices_; }
86+ [[nodiscard]] std::vector<std:: size_t >* Data () { return &row_indices_; }
87+ [[nodiscard]] std::vector<std:: size_t > const * Data () const { return &row_indices_; }
9288
9389 // split rowset into two
94- inline void AddSplit (unsigned node_id, unsigned left_node_id, unsigned right_node_id,
95- size_t n_left, size_t n_right) {
90+ void AddSplit (bst_node_t node_id, bst_node_t left_node_id, bst_node_t right_node_id,
91+ bst_idx_t n_left, bst_idx_t n_right) {
9692 const Elem e = elem_of_each_node_[node_id];
9793
98- size_t * all_begin{nullptr };
99- size_t * begin{nullptr };
94+ std:: size_t * all_begin{nullptr };
95+ std:: size_t * begin{nullptr };
10096 if (e.begin == nullptr ) {
10197 CHECK_EQ (n_left, 0 );
10298 CHECK_EQ (n_right, 0 );
10399 } else {
104- all_begin = dmlc::BeginPtr ( row_indices_);
100+ all_begin = row_indices_. data ( );
105101 begin = all_begin + (e.begin - all_begin);
106102 }
107103
108104 CHECK_EQ (n_left + n_right, e.Size ());
109105 CHECK_LE (begin + n_left, e.end );
110106 CHECK_EQ (begin + n_left + n_right, e.end );
111107
112- if (left_node_id >= elem_of_each_node_.size ()) {
113- elem_of_each_node_.resize (left_node_id + 1 , Elem ( nullptr , nullptr , -1 ) );
108+ if (left_node_id >= static_cast < bst_node_t >( elem_of_each_node_.size () )) {
109+ elem_of_each_node_.resize (left_node_id + 1 , Elem{ nullptr , nullptr , -1 } );
114110 }
115- if (right_node_id >= elem_of_each_node_.size ()) {
116- elem_of_each_node_.resize (right_node_id + 1 , Elem ( nullptr , nullptr , -1 ) );
111+ if (right_node_id >= static_cast < bst_node_t >( elem_of_each_node_.size () )) {
112+ elem_of_each_node_.resize (right_node_id + 1 , Elem{ nullptr , nullptr , -1 } );
117113 }
118114
119- elem_of_each_node_[left_node_id] = Elem ( begin, begin + n_left, left_node_id) ;
120- elem_of_each_node_[right_node_id] = Elem ( begin + n_left, e.end , right_node_id) ;
121- elem_of_each_node_[node_id] = Elem ( nullptr , nullptr , -1 ) ;
115+ elem_of_each_node_[left_node_id] = Elem{ begin, begin + n_left, left_node_id} ;
116+ elem_of_each_node_[right_node_id] = Elem{ begin + n_left, e.end , right_node_id} ;
117+ elem_of_each_node_[node_id] = Elem{ nullptr , nullptr , -1 } ;
122118 }
123119
124120 private:
125121 // stores the row indexes in the set
126- std::vector<size_t > row_indices_;
122+ std::vector<std:: size_t > row_indices_;
127123 // vector: node_id -> elements
128124 std::vector<Elem> elem_of_each_node_;
129125};
130- } // namespace common
131- } // namespace xgboost
126+ } // namespace xgboost::common
132127
133128#endif // XGBOOST_COMMON_ROW_SET_H_
0 commit comments