forked from daphne-project/daphne
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathExtractRow.h
More file actions
270 lines (239 loc) · 12.7 KB
/
Copy pathExtractRow.h
File metadata and controls
270 lines (239 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
/*
* Copyright 2021 The DAPHNE Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SRC_RUNTIME_LOCAL_KERNELS_EXTRACTROW_H
#define SRC_RUNTIME_LOCAL_KERNELS_EXTRACTROW_H
#include <runtime/local/context/DaphneContext.h>
#include <runtime/local/datastructures/DataObjectFactory.h>
#include <runtime/local/datastructures/DenseMatrix.h>
#include <runtime/local/datastructures/Frame.h>
#include <runtime/local/datastructures/Matrix.h>
#include <runtime/local/datastructures/ValueTypeCode.h>
#include <runtime/local/datastructures/ValueTypeUtils.h>
#include <sstream>
#include <stdexcept>
#include <cmath>
#include <cstddef>
#include <cstdint>
// ****************************************************************************
// Struct for partial template specialization
// ****************************************************************************
// Primary template of the ExtractRow kernel. Its apply() is deleted, so the
// kernel can only be used through a (partial) specialization for a concrete
// combination of result type (DTRes), argument type (DTArg), and value type
// of the row selection (VTSel).
template <class DTRes, class DTArg, typename VTSel>
struct ExtractRow {
    static void apply(DTRes *&res,
                      const DTArg *arg,
                      const DenseMatrix<VTSel> *sel,
                      DCTX(ctx)) = delete;
};
// ****************************************************************************
// Convenience function
// ****************************************************************************
template <class DTRes, class DTArg, typename VTSel>
void extractRow(DTRes *&res, const DTArg *arg, const DenseMatrix<VTSel> *sel, DCTX(ctx)) {
ExtractRow<DTRes, DTArg, VTSel>::apply(res, arg, sel, ctx);
}
// ****************************************************************************
// Boundary validation
// ****************************************************************************
// Checks the shape of the row selection: it must have exactly one column,
// otherwise std::runtime_error is thrown. Only the shape is checked here; the
// individual row positions are bounds-checked by the kernels while they copy
// the rows, so that the selection does not have to be scanned twice.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to a
// single statement (safe inside an unbraced if/else), and the argument is
// evaluated exactly once.
//
// NOTE(review): the message speaks of a "column selection" although the
// argument selects rows; the wording is kept unchanged for compatibility.
#define VALIDATE_ARGS(numColsSel)                                                                                      \
    do {                                                                                                               \
        const auto validateArgsNumColsSel_ = (numColsSel);                                                             \
        if (validateArgsNumColsSel_ != 1) {                                                                            \
            std::ostringstream errMsg;                                                                                 \
            errMsg << "invalid argument passed to ExtractRow: column selection "                                       \
                      "must be given as column matrix but has '"                                                       \
                   << validateArgsNumColsSel_ << "' columns instead of one";                                           \
            throw std::runtime_error(errMsg.str());                                                                    \
        }                                                                                                              \
    } while (0)
// ****************************************************************************
// (Partial) template specializations for different data/value types
// ****************************************************************************
// ----------------------------------------------------------------------------
// Frame <- Frame
// ----------------------------------------------------------------------------
// 0 (row-wise) or 1 (column-wise)
#define EXTRACTROW_FRAME_MODE 0
// Frame <- Frame: copies the rows of `arg` selected by `sel` into `res`, in
// the order in which they appear in `sel`.
//
// res  Result frame. If null, a new frame with the schema and labels of `arg`
//      is created. A non-null `res` is written to as is, without any check of
//      its shape or schema.
//      NOTE(review): presumably callers must then pass a frame with the same
//      schema and room for the padded row count computed below — confirm.
// arg  Frame to extract rows from; must not be null.
// sel  Column matrix (n x 1) of row positions in `arg`, each within
//      [0, arg->getNumRows()); must not be null.
// ctx  Kernel context (not used by this specialization).
//
// Throws std::runtime_error if `arg` or `sel` is null, if `sel` does not have
// exactly one column, or if a selected position is NaN.
// Throws std::out_of_range if a selected position is outside of `arg`.
template <typename VTSel> struct ExtractRow<Frame, Frame, VTSel> {
    static void apply(Frame *&res, const Frame *arg, const DenseMatrix<VTSel> *sel, DCTX(ctx)) {
        // input validation (consistent with the matrix specializations)
        if (arg == nullptr)
            throw std::runtime_error("invalid argument passed to ExtractRow on "
                                     "frame: arg cannot be null");
        if (sel == nullptr)
            throw std::runtime_error("invalid argument passed to ExtractRow on frame: "
                                     "rowIdxs sel cannot be null");
        VALIDATE_ARGS(sel->getNumCols());

        const size_t numRowsSel = sel->getNumRows();
        const size_t numCols = arg->getNumCols();
        const size_t numRowsArg = arg->getNumRows();
        const ValueTypeCode *schema = arg->getSchema();
#if EXTRACTROW_FRAME_MODE == 0
        // Add some padding due to stores in units of 8 bytes (see below). This
        // formula is a little pessimistic, though.
        const size_t numRowsResAlloc = numRowsSel + sizeof(uint64_t) / sizeof(uint8_t) - 1;
#elif EXTRACTROW_FRAME_MODE == 1
        const size_t numRowsResAlloc = numRowsSel;
#endif
        if (res == nullptr)
            res = DataObjectFactory::create<Frame>(numRowsResAlloc, numCols, schema, arg->getLabels(), false);
        const VTSel *valuesSel = sel->getValues();
#if EXTRACTROW_FRAME_MODE == 0
        // Some information on each column.
        const auto elementSizes = std::make_unique<size_t[]>(numCols);
        const auto argCols = std::make_unique<const uint8_t *[]>(numCols);
        auto resCols = std::make_unique<uint8_t *[]>(numCols);
        // Initialize information on each column. resCols[c] is a write cursor
        // that is advanced by one element per extracted row.
        for (size_t c = 0; c < numCols; c++) {
            elementSizes[c] = ValueTypeUtils::sizeOf(schema[c]);
            argCols[c] = reinterpret_cast<const uint8_t *>(arg->getColumnRaw(c));
            resCols[c] = reinterpret_cast<uint8_t *>(res->getColumnRaw(c));
        }
        // Actual filtering.
        for (size_t r = 0; r < numRowsSel; r++) {
            const VTSel selVal = valuesSel[r];
            // NaN compares false against everything, so it would slip through
            // the range check below and must be rejected explicitly.
            if (std::isnan(selVal)) {
                std::ostringstream errMsg;
                errMsg << "invalid argument passed to ExtractRow on "
                          "frame: rowIdxs sel value at index "
                       << r << " is NaN";
                throw std::runtime_error(errMsg.str());
            }
            // The sign is tested before the conversion to size_t, because
            // converting a negative floating-point value to an unsigned type
            // is undefined behavior.
            if (selVal < 0 || numRowsArg <= static_cast<size_t>(selVal)) {
                std::ostringstream errMsg;
                errMsg << "invalid argument '" << selVal
                       << "' passed to ExtractRow: "
                          "out of bounds for frame with row boundaries '[0, "
                       << numRowsArg << ")'";
                throw std::out_of_range(errMsg.str());
            }
            const size_t pos = static_cast<size_t>(selVal);
            for (size_t c = 0; c < numCols; c++) {
                if (schema[c] == ValueTypeCode::STR) {
                    // Handle std::string column: strings need a real copy
                    // assignment, a raw byte copy would not be valid.
                    *reinterpret_cast<std::string *>(resCols[c]) =
                        *reinterpret_cast<const std::string *>(argCols[c] + pos * elementSizes[c]);
                    resCols[c] += elementSizes[c];
                } else {
                    // We always copy in units of 8 bytes (uint64_t). If the
                    // actual element size is lower, the superfluous bytes will
                    // be overwritten by the next match. With this approach, we
                    // do not need to call memcpy for each element, nor
                    // interpret the types for a L/S of fitting size.
                    // NOTE(review): for element sizes below 8 bytes, the load
                    // for the last row of arg reads past that row's element —
                    // presumably the column buffers tolerate this; confirm.
                    // TODO Don't multiply by elementSize, but left-shift by
                    // ld(elementSize).
                    *reinterpret_cast<uint64_t *>(resCols[c]) =
                        *reinterpret_cast<const uint64_t *>(argCols[c] + pos * elementSizes[c]);
                    resCols[c] += elementSizes[c];
                }
            }
        }
        // Drop the padding rows again.
        res->shrinkNumRows(numRowsSel);
#elif EXTRACTROW_FRAME_MODE == 1
        // TODO Implement a columnar approach.
#endif
    }
};
#undef EXTRACTROW_FRAME_MODE
// ----------------------------------------------------------------------------
// DenseMatrix <- DenseMatrix
// ----------------------------------------------------------------------------
// DenseMatrix <- DenseMatrix: copies the rows of `arg` whose positions are
// listed in the column matrix `sel` into `res`, in the order given by `sel`.
//
// res  Result matrix. If null, a matrix of sel->getNumRows() x
//      arg->getNumCols() is created; otherwise it must already have exactly
//      these dimensions.
// arg  Matrix to extract rows from; must not be null.
// sel  Column matrix (n x 1) of row positions in `arg`; must not be null.
//
// Throws std::runtime_error for null inputs, a selection that is not a single
// column, a pre-allocated `res` of the wrong shape, or a NaN position.
// Throws std::out_of_range for a position outside of [0, arg->getNumRows()).
template <typename VT, typename VTSel> struct ExtractRow<DenseMatrix<VT>, DenseMatrix<VT>, VTSel> {
    static void apply(DenseMatrix<VT> *&res, const DenseMatrix<VT> *arg, const DenseMatrix<VTSel> *sel, DCTX(ctx)) {
        // Reject missing inputs before touching them.
        if (arg == nullptr)
            throw std::runtime_error("invalid argument passed to ExtractRow on "
                                     "dense matrix: arg cannot be null");
        if (sel == nullptr)
            throw std::runtime_error("invalid argument passed to ExtractRow on dense matrix: "
                                     "rowIdxs sel cannot be null");
        VALIDATE_ARGS(sel->getNumCols());

        const size_t selCount = sel->getNumRows();
        const size_t argRowCount = arg->getNumRows();
        const size_t argColCount = arg->getNumCols();

        if (res == nullptr) {
            res = DataObjectFactory::create<DenseMatrix<VT>>(selCount, argColCount, false);
        } else {
            const bool shapeMatches = res->getNumRows() == selCount && res->getNumCols() == argColCount;
            if (!shapeMatches) {
                // TODO Decide on the best strategy here: warn, or simply
                // re-allocate the result?
                std::ostringstream errMsg;
                errMsg << "invalid argument passed to ExtractRow on dense matrix: "
                          "res was not null, but given res has wrong dimensions "
                       << res->getNumRows() << "x" << res->getNumCols() << " instead of " << selCount << "x"
                       << argColCount;
                throw std::runtime_error(errMsg.str());
            }
        }

        // Copy row by row; dstRow always points to the next row to fill.
        VT *dstRow = res->getValues();
        const VTSel *selValues = sel->getValues();
        for (size_t i = 0; i < selCount; ++i) {
            const VTSel rowIdx = selValues[i]; // sel has a single column
            // TODO These per-row checks could be skipped or made optional
            // for performance, but they are acceptable for now.
            if (std::isnan(rowIdx)) {
                std::ostringstream errMsg;
                errMsg << "invalid argument passed to ExtractRow on dense "
                          "matrix: rowIdxs sel value at index "
                       << i << " is NaN";
                throw std::runtime_error(errMsg.str());
            }
            if (rowIdx < 0 || argRowCount <= static_cast<const size_t>(rowIdx)) {
                std::ostringstream errMsg;
                errMsg << "invalid argument '" << rowIdx
                       << "' passed to ExtractRow: out of bounds for "
                          "matrix with row boundaries '[0, "
                       << argRowCount << ")'";
                throw std::out_of_range(errMsg.str());
            }

            // The row skip (not the column count) is the distance between
            // the starts of two consecutive rows.
            const VT *srcRow = arg->getValues() + static_cast<const size_t>(rowIdx) * arg->getRowSkip();
            for (size_t col = 0; col != argColCount; ++col)
                dstRow[col] = srcRow[col];
            dstRow += res->getRowSkip();
        }
    }
};
// ----------------------------------------------------------------------------
// Matrix <- Matrix
// ----------------------------------------------------------------------------
// Matrix <- Matrix: generic row extraction that only relies on the
// get()/append() interface of Matrix. The rows of `arg` listed in the column
// matrix `sel` are appended to `res` in the order given by `sel`.
//
// res  Result matrix. If null, a DenseMatrix of sel->getNumRows() x
//      arg->getNumCols() is created; otherwise it must already have exactly
//      these dimensions.
// arg  Matrix to extract rows from; must not be null.
// sel  Column matrix (n x 1) of row positions in `arg`; must not be null.
//
// Throws std::runtime_error for null inputs, a selection that is not a single
// column, a pre-allocated `res` of the wrong shape, or a NaN position.
// Throws std::out_of_range for a position outside of [0, arg->getNumRows()).
template <typename VT, typename VTSel> struct ExtractRow<Matrix<VT>, Matrix<VT>, VTSel> {
    static void apply(Matrix<VT> *&res, const Matrix<VT> *arg, const Matrix<VTSel> *sel, DCTX(ctx)) {
        // Reject missing inputs before touching them.
        if (arg == nullptr)
            throw std::runtime_error("invalid argument passed to ExtractRow on "
                                     "dense matrix: arg cannot be null");
        if (sel == nullptr)
            throw std::runtime_error("invalid argument passed to ExtractRow on dense matrix: "
                                     "rowIdxs sel cannot be null");
        VALIDATE_ARGS(sel->getNumCols());

        const size_t selCount = sel->getNumRows();
        const size_t argRowCount = arg->getNumRows();
        const size_t argColCount = arg->getNumCols();

        if (res == nullptr) {
            res = DataObjectFactory::create<DenseMatrix<VT>>(selCount, argColCount, false);
        } else {
            const bool shapeMatches = res->getNumRows() == selCount && res->getNumCols() == argColCount;
            if (!shapeMatches) {
                std::ostringstream errMsg;
                errMsg << "invalid argument passed to ExtractRow on dense matrix: "
                          "res was not null, but given res has wrong dimensions "
                       << res->getNumRows() << "x" << res->getNumCols() << " instead of " << selCount << "x"
                       << argColCount;
                throw std::runtime_error(errMsg.str());
            }
        }

        // Fill the result through the append interface, one selected row at
        // a time.
        res->prepareAppend();
        for (size_t i = 0; i < selCount; ++i) {
            const VTSel rowIdx = sel->get(i, 0); // sel has a single column
            if (std::isnan(rowIdx)) {
                std::ostringstream errMsg;
                errMsg << "invalid argument passed to ExtractRow on dense "
                          "matrix: rowIdxs sel value at index "
                       << i << " is NaN";
                throw std::runtime_error(errMsg.str());
            }
            if (rowIdx < 0 || argRowCount <= static_cast<const size_t>(rowIdx)) {
                std::ostringstream errMsg;
                errMsg << "invalid argument '" << rowIdx
                       << "' passed to ExtractRow: out of bounds for "
                          "matrix with row boundaries '[0, "
                       << argRowCount << ")'";
                throw std::out_of_range(errMsg.str());
            }

            for (size_t col = 0; col != argColCount; ++col)
                res->append(i, col, arg->get(static_cast<const size_t>(rowIdx), col));
        }
        res->finishAppend();
    }
};
#undef VALIDATE_ARGS
#endif // SRC_RUNTIME_LOCAL_KERNELS_EXTRACTROW_H