forked from daphne-project/daphne
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMap.h
More file actions
98 lines (78 loc) · 3.83 KB
/
Copy pathMap.h
File metadata and controls
98 lines (78 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/*
* Copyright 2021 The DAPHNE Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SRC_RUNTIME_LOCAL_KERNELS_MAP_H
#define SRC_RUNTIME_LOCAL_KERNELS_MAP_H
#include <runtime/local/context/DaphneContext.h>
#include <runtime/local/datastructures/DataObjectFactory.h>
#include <runtime/local/datastructures/DenseMatrix.h>
#include <runtime/local/datastructures/Matrix.h>
#include <algorithm>
#include <stdexcept>
// ****************************************************************************
// Struct for partial template specialization
// ****************************************************************************
/**
 * @brief Kernel applying a user-defined function (UDF) element-wise to a data
 * object.
 *
 * The primary template deliberately has no implementation (`apply` is
 * deleted). Only data-type combinations that have a partial specialization
 * can actually be called.
 *
 * The UDF travels as an untyped `void *`. A more precise signature such as
 * `(DTRes::VT)(*func)(DTArg::VT)` would be preferable, but kernels.json
 * cannot currently express it.
 */
template <typename DTRes, typename DTArg>
struct Map {
    static void apply(DTRes *&res, const DTArg *arg, void *func, DCTX(ctx)) = delete;
};
// ****************************************************************************
// Convenience function
// ****************************************************************************
template <class DTRes, class DTArg> void map(DTRes *&res, const DTArg *arg, void *func, DCTX(ctx)) {
Map<DTRes, DTArg>::apply(res, arg, func, ctx);
}
// ****************************************************************************
// (Partial) template specializations for different data/value types
// ****************************************************************************
// ----------------------------------------------------------------------------
// DenseMatrix
// ----------------------------------------------------------------------------
template <typename VTRes, typename VTArg> struct Map<DenseMatrix<VTRes>, DenseMatrix<VTArg>> {
    /**
     * @brief Applies the scalar UDF `func` to every element of `arg` and
     * stores each result at the same position in `res`.
     *
     * @param res Result matrix. If `nullptr`, a new dense matrix with the
     *            dimensions of `arg` is allocated. Otherwise it must already
     *            have exactly the dimensions of `arg`.
     * @param arg Input matrix.
     * @param func Type-erased pointer to a function of signature `VTRes(VTArg)`.
     * @param ctx The DAPHNE context (not used by this kernel).
     * @throws std::runtime_error If a pre-allocated `res` does not have the
     *         same dimensions as `arg`.
     */
    static void apply(DenseMatrix<VTRes> *&res, const DenseMatrix<VTArg> *arg, void *func, DCTX(ctx)) {
        const size_t numRows = arg->getNumRows();
        const size_t numCols = arg->getNumCols();

        if (res == nullptr)
            res = DataObjectFactory::create<DenseMatrix<VTRes>>(numRows, numCols, false);
        else if (res->getNumRows() != numRows || res->getNumCols() != numCols)
            // The loop below writes numRows x numCols values into res; a
            // differently-shaped result would be written out of bounds.
            throw std::runtime_error("Map: the result matrix must have the same dimensions as the argument matrix");

        // The UDF arrives type-erased; restore its actual signature.
        auto udf = reinterpret_cast<VTRes (*)(VTArg)>(func);

        // Rows are advanced by each matrix's row skip rather than numCols,
        // since the two may differ (presumably for sub-matrix views). The row
        // skips are loop-invariant, so they are read once up front.
        const size_t rowSkipArg = arg->getRowSkip();
        const size_t rowSkipRes = res->getRowSkip();
        const VTArg *valuesArg = arg->getValues();
        VTRes *valuesRes = res->getValues();
        for (size_t r = 0; r < numRows; r++) {
            for (size_t c = 0; c < numCols; c++)
                valuesRes[c] = udf(valuesArg[c]);
            valuesArg += rowSkipArg;
            valuesRes += rowSkipRes;
        }
    }
};
// ----------------------------------------------------------------------------
// Matrix
// ----------------------------------------------------------------------------
template <typename VTRes, typename VTArg> struct Map<Matrix<VTRes>, Matrix<VTArg>> {
    /**
     * @brief Applies the scalar UDF `func` to every element of `arg` through
     * the generic `Matrix` interface and appends each result to `res`.
     *
     * @param res Result matrix. If `nullptr`, a new dense matrix with the
     *            dimensions of `arg` is allocated. Otherwise it must already
     *            have exactly the dimensions of `arg`.
     * @param arg Input matrix (any `Matrix` implementation; read via `get`).
     * @param func Type-erased pointer to a function of signature `VTRes(VTArg)`.
     * @param ctx The DAPHNE context (not used by this kernel).
     * @throws std::runtime_error If a pre-allocated `res` does not have the
     *         same dimensions as `arg`.
     */
    static void apply(Matrix<VTRes> *&res, const Matrix<VTArg> *arg, void *func, DCTX(ctx)) {
        const size_t numRows = arg->getNumRows();
        const size_t numCols = arg->getNumCols();

        if (res == nullptr)
            res = DataObjectFactory::create<DenseMatrix<VTRes>>(numRows, numCols, false);
        else if (res->getNumRows() != numRows || res->getNumCols() != numCols)
            // Every (r, c) of arg is appended to res below; a differently-shaped
            // result cannot hold those positions.
            throw std::runtime_error("Map: the result matrix must have the same dimensions as the argument matrix");

        // The UDF arrives type-erased; restore its actual signature.
        auto udf = reinterpret_cast<VTRes (*)(VTArg)>(func);

        // Results are appended in row-major order between prepareAppend() and
        // finishAppend(), using only the generic Matrix interface.
        res->prepareAppend();
        for (size_t r = 0; r < numRows; ++r)
            for (size_t c = 0; c < numCols; ++c)
                res->append(r, c, udf(arg->get(r, c)));
        res->finishAppend();
    }
};
#endif // SRC_RUNTIME_LOCAL_KERNELS_MAP_H