@@ -23,20 +23,20 @@ struct EllpackDeviceAccessor {
23
23
bool is_dense;
24
24
/* ! \brief Row length for ELLPACK, equal to number of features. */
25
25
size_t row_stride;
26
- size_t base_rowid{};
27
- size_t n_rows{};
28
- common::CompressedIterator<uint32_t > gidx_iter;
26
+ bst_idx_t base_rowid{0 };
27
+ bst_idx_t n_rows{0 };
28
+ common::CompressedIterator<std:: uint32_t > gidx_iter;
29
29
/* ! \brief Minimum value for each feature. Size equals to number of features. */
30
- common::Span<const bst_float > min_fvalue;
30
+ common::Span<const float > min_fvalue;
31
31
/* ! \brief Histogram cut pointers. Size equals to (number of features + 1). */
32
- common::Span<const uint32_t > feature_segments;
32
+ common::Span<const std:: uint32_t > feature_segments;
33
33
/* ! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
34
- common::Span<const bst_float > gidx_fvalue_map;
34
+ common::Span<const float > gidx_fvalue_map;
35
35
36
36
common::Span<const FeatureType> feature_types;
37
37
38
- EllpackDeviceAccessor (DeviceOrd device, const common::HistogramCuts& cuts, bool is_dense ,
39
- size_t row_stride, size_t base_rowid, size_t n_rows,
38
+ EllpackDeviceAccessor (DeviceOrd device, std::shared_ptr< const common::HistogramCuts> cuts,
39
+ bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows,
40
40
common::CompressedIterator<uint32_t > gidx_iter,
41
41
common::Span<FeatureType const > feature_types)
42
42
: is_dense(is_dense),
@@ -46,16 +46,16 @@ struct EllpackDeviceAccessor {
46
46
gidx_iter(gidx_iter),
47
47
feature_types{feature_types} {
48
48
if (device.IsCPU ()) {
49
- gidx_fvalue_map = cuts. cut_values_ .ConstHostSpan ();
50
- feature_segments = cuts. cut_ptrs_ .ConstHostSpan ();
51
- min_fvalue = cuts. min_vals_ .ConstHostSpan ();
49
+ gidx_fvalue_map = cuts-> cut_values_ .ConstHostSpan ();
50
+ feature_segments = cuts-> cut_ptrs_ .ConstHostSpan ();
51
+ min_fvalue = cuts-> min_vals_ .ConstHostSpan ();
52
52
} else {
53
- cuts. cut_values_ .SetDevice (device);
54
- cuts. cut_ptrs_ .SetDevice (device);
55
- cuts. min_vals_ .SetDevice (device);
56
- gidx_fvalue_map = cuts. cut_values_ .ConstDeviceSpan ();
57
- feature_segments = cuts. cut_ptrs_ .ConstDeviceSpan ();
58
- min_fvalue = cuts. min_vals_ .ConstDeviceSpan ();
53
+ cuts-> cut_values_ .SetDevice (device);
54
+ cuts-> cut_ptrs_ .SetDevice (device);
55
+ cuts-> min_vals_ .SetDevice (device);
56
+ gidx_fvalue_map = cuts-> cut_values_ .ConstDeviceSpan ();
57
+ feature_segments = cuts-> cut_ptrs_ .ConstDeviceSpan ();
58
+ min_fvalue = cuts-> min_vals_ .ConstDeviceSpan ();
59
59
}
60
60
}
61
61
// Get a matrix element, uses binary search for look up Return NaN if missing
@@ -142,13 +142,14 @@ class EllpackPageImpl {
142
142
* This is used in the sampling case. The ELLPACK page is constructed from an existing EllpackInfo
143
143
* and the given number of rows.
144
144
*/
145
- EllpackPageImpl (DeviceOrd device, common::HistogramCuts cuts, bool is_dense, size_t row_stride ,
146
- size_t n_rows);
145
+ EllpackPageImpl (DeviceOrd device, std::shared_ptr< common::HistogramCuts const > cuts ,
146
+ bool is_dense, size_t row_stride, size_t n_rows);
147
147
/* !
148
148
* \brief Constructor used for external memory.
149
149
*/
150
- EllpackPageImpl (DeviceOrd device, common::HistogramCuts cuts, const SparsePage& page,
151
- bool is_dense, size_t row_stride, common::Span<FeatureType const > feature_types);
150
+ EllpackPageImpl (DeviceOrd device, std::shared_ptr<common::HistogramCuts const > cuts,
151
+ const SparsePage& page, bool is_dense, size_t row_stride,
152
+ common::Span<FeatureType const > feature_types);
152
153
153
154
/* !
154
155
* \brief Constructor from an existing DMatrix.
@@ -162,7 +163,7 @@ class EllpackPageImpl {
162
163
explicit EllpackPageImpl (AdapterBatch batch, float missing, DeviceOrd device, bool is_dense,
163
164
common::Span<size_t > row_counts_span,
164
165
common::Span<FeatureType const > feature_types, size_t row_stride,
165
- size_t n_rows, common::HistogramCuts const & cuts);
166
+ size_t n_rows, std::shared_ptr< common::HistogramCuts const > cuts);
166
167
/* *
167
168
* \brief Constructor from an existing CPU gradient index.
168
169
*/
@@ -194,16 +195,17 @@ class EllpackPageImpl {
194
195
base_rowid = row_id;
195
196
}
196
197
197
- [[nodiscard]] common::HistogramCuts& Cuts () { return cuts_; }
198
- [[nodiscard]] common::HistogramCuts const & Cuts () const { return cuts_; }
198
+ [[nodiscard]] common::HistogramCuts const & Cuts () const { return *cuts_; }
199
+ [[nodiscard]] std::shared_ptr<common::HistogramCuts const > CutsShared () const { return cuts_; }
200
+ void SetCuts (std::shared_ptr<common::HistogramCuts const > cuts) { cuts_ = cuts; }
199
201
200
202
/* ! \return Estimation of memory cost of this page. */
201
203
static size_t MemCostBytes (size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ;
202
204
203
205
204
206
/* ! \brief Return the total number of symbols (total number of bins plus 1 for
205
207
* not found). */
206
- [[nodiscard]] std::size_t NumSymbols () const { return cuts_. TotalBins () + 1 ; }
208
+ [[nodiscard]] std::size_t NumSymbols () const { return cuts_-> TotalBins () + 1 ; }
207
209
208
210
[[nodiscard]] EllpackDeviceAccessor GetDeviceAccessor (
209
211
DeviceOrd device, common::Span<FeatureType const > feature_types = {}) const ;
@@ -225,19 +227,18 @@ class EllpackPageImpl {
225
227
*/
226
228
void InitCompressedData (DeviceOrd device);
227
229
228
-
229
- public:
230
+ public:
230
231
/* ! \brief Whether or not if the matrix is dense. */
231
232
bool is_dense;
232
233
/* ! \brief Row length for ELLPACK. */
233
234
size_t row_stride;
234
- size_t base_rowid{0 };
235
- size_t n_rows{};
235
+ bst_idx_t base_rowid{0 };
236
+ bst_idx_t n_rows{};
236
237
/* ! \brief global index of histogram, which is stored in ELLPACK format. */
237
238
HostDeviceVector<common::CompressedByteT> gidx_buffer;
238
239
239
240
private:
240
- common::HistogramCuts cuts_;
241
+ std::shared_ptr< common::HistogramCuts const > cuts_;
241
242
common::Monitor monitor_;
242
243
};
243
244
0 commit comments