Skip to content

Commit e5f1720

Browse files
authored
[EM] Avoid writing cut matrix to cache. (dmlc#10444)
1 parent 63418d2 commit e5f1720

31 files changed

+423
-292
lines changed

include/xgboost/data.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ class MetaInfo {
113113
MetaInfo Slice(common::Span<int32_t const> ridxs) const;
114114

115115
MetaInfo Copy() const;
116-
116+
/**
117+
* @brief Whether the matrix is dense.
118+
*/
119+
bool IsDense() const { return num_col_ * num_row_ == num_nonzero_; }
117120
/*!
118121
* \brief Get weight of each instances.
119122
* \param i Instance index.
@@ -538,10 +541,10 @@ class DMatrix {
538541
/*! \brief virtual destructor */
539542
virtual ~DMatrix();
540543

541-
/*! \brief Whether the matrix is dense. */
542-
[[nodiscard]] bool IsDense() const {
543-
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
544-
}
544+
/**
545+
* @brief Whether the matrix is dense.
546+
*/
547+
[[nodiscard]] bool IsDense() const { return this->Info().IsDense(); }
545548

546549
/**
547550
* \brief Load DMatrix from URI.

src/collective/aggregator.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <limits>
1010
#include <string>
1111
#include <utility>
12-
#include <vector>
1312

1413
#include "allreduce.h"
1514
#include "broadcast.h"

src/common/hist_util.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,17 @@ class HistogramCuts {
162162
}
163163
return vals[bin_idx - 1];
164164
}
165+
166+
void SetDevice(DeviceOrd d) const {
167+
this->cut_ptrs_.SetDevice(d);
168+
this->cut_ptrs_.ConstDevicePointer();
169+
170+
this->cut_values_.SetDevice(d);
171+
this->cut_values_.ConstDevicePointer();
172+
173+
this->min_vals_.SetDevice(d);
174+
this->min_vals_.ConstDevicePointer();
175+
}
165176
};
166177

167178
/**

src/data/ellpack_page.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
/**
2-
* Copyright 2019-2023, XGBoost contributors
2+
* Copyright 2019-2024, XGBoost contributors
33
*/
44
#ifndef XGBOOST_USE_CUDA
55

66
#include "ellpack_page.h"
77

88
#include <xgboost/data.h>
99

10+
#include <memory> // for shared_ptr
11+
1012
// dummy implementation of EllpackPage in case CUDA is not used
1113
namespace xgboost {
1214

1315
class EllpackPageImpl {
14-
common::HistogramCuts cuts_;
16+
std::shared_ptr<common::HistogramCuts> cuts_;
1517

1618
public:
17-
[[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; }
18-
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; }
19+
[[nodiscard]] common::HistogramCuts const& Cuts() const { return *cuts_; }
20+
[[nodiscard]] std::shared_ptr<common::HistogramCuts const> CutsShared() const { return cuts_; }
1921
};
2022

2123
EllpackPage::EllpackPage() = default;
@@ -40,12 +42,6 @@ size_t EllpackPage::Size() const {
4042
return 0;
4143
}
4244

43-
[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() {
44-
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
45-
"EllpackPage is required";
46-
return impl_->Cuts();
47-
}
48-
4945
[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const {
5046
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
5147
"EllpackPage is required";

src/data/ellpack_page.cu

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
#include "../common/cuda_context.cuh"
1313
#include "../common/hist_util.cuh"
1414
#include "../common/transform_iterator.h" // MakeIndexTransformIter
15-
#include "./ellpack_page.cuh"
16-
#include "device_adapter.cuh" // for NoInfInData
15+
#include "device_adapter.cuh" // for NoInfInData
16+
#include "ellpack_page.cuh"
1717
#include "ellpack_page.h"
1818
#include "gradient_index.h"
1919
#include "xgboost/data.h"
@@ -33,11 +33,6 @@ size_t EllpackPage::Size() const { return impl_->Size(); }
3333

3434
void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); }
3535

36-
[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() {
37-
CHECK(impl_);
38-
return impl_->Cuts();
39-
}
40-
4136
[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const {
4237
CHECK(impl_);
4338
return impl_->Cuts();
@@ -94,7 +89,8 @@ __global__ void CompressBinEllpackKernel(
9489
}
9590

9691
// Construct an ELLPACK matrix with the given number of empty rows.
97-
EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, bool is_dense,
92+
EllpackPageImpl::EllpackPageImpl(DeviceOrd device,
93+
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
9894
size_t row_stride, size_t n_rows)
9995
: is_dense(is_dense), cuts_(std::move(cuts)), row_stride(row_stride), n_rows(n_rows) {
10096
monitor_.Init("ellpack_page");
@@ -105,12 +101,11 @@ EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, b
105101
monitor_.Stop("InitCompressedData");
106102
}
107103

108-
EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts,
109-
const SparsePage &page, bool is_dense,
110-
size_t row_stride,
104+
EllpackPageImpl::EllpackPageImpl(DeviceOrd device,
105+
std::shared_ptr<common::HistogramCuts const> cuts,
106+
const SparsePage& page, bool is_dense, size_t row_stride,
111107
common::Span<FeatureType const> feature_types)
112-
: cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()),
113-
row_stride(row_stride) {
108+
: cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()), row_stride(row_stride) {
114109
this->InitCompressedData(device);
115110
this->CreateHistIndices(device, page, feature_types);
116111
}
@@ -127,9 +122,10 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
127122
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
128123
row_stride = GetRowStride(dmat);
129124
if (!param.hess.empty()) {
130-
cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess);
125+
cuts_ = std::make_shared<common::HistogramCuts>(
126+
common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess));
131127
} else {
132-
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
128+
cuts_ = std::make_shared<common::HistogramCuts>(common::DeviceSketch(ctx, dmat, param.max_bin));
133129
}
134130
monitor_.Stop("Quantiles");
135131

@@ -297,7 +293,7 @@ template <typename AdapterBatch>
297293
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd device, bool is_dense,
298294
common::Span<size_t> row_counts_span,
299295
common::Span<FeatureType const> feature_types, size_t row_stride,
300-
size_t n_rows, common::HistogramCuts const& cuts) {
296+
size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts) {
301297
dh::safe_cuda(cudaSetDevice(device.ordinal));
302298

303299
*this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
@@ -309,7 +305,7 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd de
309305
template EllpackPageImpl::EllpackPageImpl( \
310306
__BATCH_T batch, float missing, DeviceOrd device, bool is_dense, \
311307
common::Span<size_t> row_counts_span, common::Span<FeatureType const> feature_types, \
312-
size_t row_stride, size_t n_rows, common::HistogramCuts const& cuts);
308+
size_t row_stride, size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
313309

314310
ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
315311
ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
@@ -359,7 +355,11 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const>
359355

360356
EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
361357
common::Span<FeatureType const> ft)
362-
: is_dense{page.IsDense()}, base_rowid{page.base_rowid}, n_rows{page.Size()}, cuts_{page.cut} {
358+
: is_dense{page.IsDense()},
359+
base_rowid{page.base_rowid},
360+
n_rows{page.Size()},
361+
// This makes a copy of the cut values.
362+
cuts_{std::make_shared<common::HistogramCuts>(page.cut)} {
363363
auto it = common::MakeIndexTransformIter(
364364
[&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
365365
row_stride = *std::max_element(it, it + page.Size());

src/data/ellpack_page.cuh

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,20 @@ struct EllpackDeviceAccessor {
2323
bool is_dense;
2424
/*! \brief Row length for ELLPACK, equal to number of features. */
2525
size_t row_stride;
26-
size_t base_rowid{};
27-
size_t n_rows{};
28-
common::CompressedIterator<uint32_t> gidx_iter;
26+
bst_idx_t base_rowid{0};
27+
bst_idx_t n_rows{0};
28+
common::CompressedIterator<std::uint32_t> gidx_iter;
2929
/*! \brief Minimum value for each feature. Size equals to number of features. */
30-
common::Span<const bst_float> min_fvalue;
30+
common::Span<const float> min_fvalue;
3131
/*! \brief Histogram cut pointers. Size equals to (number of features + 1). */
32-
common::Span<const uint32_t> feature_segments;
32+
common::Span<const std::uint32_t> feature_segments;
3333
/*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
34-
common::Span<const bst_float> gidx_fvalue_map;
34+
common::Span<const float> gidx_fvalue_map;
3535

3636
common::Span<const FeatureType> feature_types;
3737

38-
EllpackDeviceAccessor(DeviceOrd device, const common::HistogramCuts& cuts, bool is_dense,
39-
size_t row_stride, size_t base_rowid, size_t n_rows,
38+
EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr<const common::HistogramCuts> cuts,
39+
bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows,
4040
common::CompressedIterator<uint32_t> gidx_iter,
4141
common::Span<FeatureType const> feature_types)
4242
: is_dense(is_dense),
@@ -46,16 +46,16 @@ struct EllpackDeviceAccessor {
4646
gidx_iter(gidx_iter),
4747
feature_types{feature_types} {
4848
if (device.IsCPU()) {
49-
gidx_fvalue_map = cuts.cut_values_.ConstHostSpan();
50-
feature_segments = cuts.cut_ptrs_.ConstHostSpan();
51-
min_fvalue = cuts.min_vals_.ConstHostSpan();
49+
gidx_fvalue_map = cuts->cut_values_.ConstHostSpan();
50+
feature_segments = cuts->cut_ptrs_.ConstHostSpan();
51+
min_fvalue = cuts->min_vals_.ConstHostSpan();
5252
} else {
53-
cuts.cut_values_.SetDevice(device);
54-
cuts.cut_ptrs_.SetDevice(device);
55-
cuts.min_vals_.SetDevice(device);
56-
gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan();
57-
feature_segments = cuts.cut_ptrs_.ConstDeviceSpan();
58-
min_fvalue = cuts.min_vals_.ConstDeviceSpan();
53+
cuts->cut_values_.SetDevice(device);
54+
cuts->cut_ptrs_.SetDevice(device);
55+
cuts->min_vals_.SetDevice(device);
56+
gidx_fvalue_map = cuts->cut_values_.ConstDeviceSpan();
57+
feature_segments = cuts->cut_ptrs_.ConstDeviceSpan();
58+
min_fvalue = cuts->min_vals_.ConstDeviceSpan();
5959
}
6060
}
6161
// Get a matrix element, uses binary search for look up Return NaN if missing
@@ -142,13 +142,14 @@ class EllpackPageImpl {
142142
* This is used in the sampling case. The ELLPACK page is constructed from an existing EllpackInfo
143143
* and the given number of rows.
144144
*/
145-
EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, bool is_dense, size_t row_stride,
146-
size_t n_rows);
145+
EllpackPageImpl(DeviceOrd device, std::shared_ptr<common::HistogramCuts const> cuts,
146+
bool is_dense, size_t row_stride, size_t n_rows);
147147
/*!
148148
* \brief Constructor used for external memory.
149149
*/
150-
EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, const SparsePage& page,
151-
bool is_dense, size_t row_stride, common::Span<FeatureType const> feature_types);
150+
EllpackPageImpl(DeviceOrd device, std::shared_ptr<common::HistogramCuts const> cuts,
151+
const SparsePage& page, bool is_dense, size_t row_stride,
152+
common::Span<FeatureType const> feature_types);
152153

153154
/*!
154155
* \brief Constructor from an existing DMatrix.
@@ -162,7 +163,7 @@ class EllpackPageImpl {
162163
explicit EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd device, bool is_dense,
163164
common::Span<size_t> row_counts_span,
164165
common::Span<FeatureType const> feature_types, size_t row_stride,
165-
size_t n_rows, common::HistogramCuts const& cuts);
166+
size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
166167
/**
167168
* \brief Constructor from an existing CPU gradient index.
168169
*/
@@ -194,16 +195,17 @@ class EllpackPageImpl {
194195
base_rowid = row_id;
195196
}
196197

197-
[[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; }
198-
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; }
198+
[[nodiscard]] common::HistogramCuts const& Cuts() const { return *cuts_; }
199+
[[nodiscard]] std::shared_ptr<common::HistogramCuts const> CutsShared() const { return cuts_; }
200+
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) { cuts_ = cuts; }
199201

200202
/*! \return Estimation of memory cost of this page. */
201203
static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ;
202204

203205

204206
/*! \brief Return the total number of symbols (total number of bins plus 1 for
205207
* not found). */
206-
[[nodiscard]] std::size_t NumSymbols() const { return cuts_.TotalBins() + 1; }
208+
[[nodiscard]] std::size_t NumSymbols() const { return cuts_->TotalBins() + 1; }
207209

208210
[[nodiscard]] EllpackDeviceAccessor GetDeviceAccessor(
209211
DeviceOrd device, common::Span<FeatureType const> feature_types = {}) const;
@@ -225,19 +227,18 @@ class EllpackPageImpl {
225227
*/
226228
void InitCompressedData(DeviceOrd device);
227229

228-
229-
public:
230+
public:
230231
/*! \brief Whether or not if the matrix is dense. */
231232
bool is_dense;
232233
/*! \brief Row length for ELLPACK. */
233234
size_t row_stride;
234-
size_t base_rowid{0};
235-
size_t n_rows{};
235+
bst_idx_t base_rowid{0};
236+
bst_idx_t n_rows{};
236237
/*! \brief global index of histogram, which is stored in ELLPACK format. */
237238
HostDeviceVector<common::CompressedByteT> gidx_buffer;
238239

239240
private:
240-
common::HistogramCuts cuts_;
241+
std::shared_ptr<common::HistogramCuts const> cuts_;
241242
common::Monitor monitor_;
242243
};
243244

src/data/ellpack_page.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ class EllpackPage {
4949
[[nodiscard]] const EllpackPageImpl* Impl() const { return impl_.get(); }
5050
EllpackPageImpl* Impl() { return impl_.get(); }
5151

52-
[[nodiscard]] common::HistogramCuts& Cuts();
5352
[[nodiscard]] common::HistogramCuts const& Cuts() const;
5453

5554
private:

0 commit comments

Comments
 (0)