forked from daphne-project/daphne
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCondMatMatMat.h
More file actions
121 lines (101 loc) · 5.5 KB
/
Copy pathCondMatMatMat.h
File metadata and controls
121 lines (101 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
/*
* Copyright 2021 The DAPHNE Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SRC_RUNTIME_LOCAL_KERNELS_CONDMATMATMAT_H
#define SRC_RUNTIME_LOCAL_KERNELS_CONDMATMATMAT_H
#include <runtime/local/context/DaphneContext.h>
#include <runtime/local/datastructures/DataObjectFactory.h>
#include <runtime/local/datastructures/DenseMatrix.h>
#include <runtime/local/datastructures/Matrix.h>
#include <sstream>
#include <stdexcept>
// ****************************************************************************
// Struct for partial template specialization
// ****************************************************************************
/**
 * @brief Primary template of the matrix/matrix/matrix conditional kernel.
 *
 * The generic `apply` is deleted on purpose: only the (partial) template
 * specializations for concrete data type combinations provide an
 * implementation, so instantiating an unsupported combination fails at
 * compile time.
 *
 * @tparam DTRes  Data type of the result.
 * @tparam DTCond Data type of the condition argument.
 * @tparam DTThen Data type of the then-argument.
 * @tparam DTElse Data type of the else-argument.
 */
template <typename DTRes, typename DTCond, typename DTThen, typename DTElse> struct CondMatMatMat {
    static void apply(DTRes *&res, const DTCond *cond, const DTThen *thenVal, const DTElse *elseVal,
                      DCTX(ctx)) = delete;
};
// ****************************************************************************
// Convenience function
// ****************************************************************************
template <class DTRes, class DTCond, class DTThen, class DTElse>
void condMatMatMat(DTRes *&res, const DTCond *cond, const DTThen *thenVal, const DTElse *elseVal, DCTX(ctx)) {
CondMatMatMat<DTRes, DTCond, DTThen, DTElse>::apply(res, cond, thenVal, elseVal, ctx);
}
// ****************************************************************************
// (Partial) template specializations for different data/value types
// ****************************************************************************
// ----------------------------------------------------------------------------
// DenseMatrix <- DenseMatrix, DenseMatrix, DenseMatrix
// ----------------------------------------------------------------------------
/**
 * @brief Element-wise selection on dense matrices.
 *
 * For every cell (r, c) the result is `thenVal(r, c)` if `cond(r, c)` converts
 * to `true`, and `elseVal(r, c)` otherwise.
 *
 * @tparam VTVal  Value type of the then/else/result matrices.
 * @tparam VTCond Value type of the condition matrix; each value is converted
 *                with `static_cast<bool>`.
 */
template <typename VTVal, typename VTCond>
struct CondMatMatMat<DenseMatrix<VTVal>, DenseMatrix<VTCond>, DenseMatrix<VTVal>, DenseMatrix<VTVal>> {
    /**
     * @param res     Result matrix. If `nullptr`, a new matrix with the shape of
     *                `cond` is created; otherwise it must already have that shape.
     * @param cond    Condition matrix.
     * @param thenVal Values taken where the condition is true.
     * @param elseVal Values taken where the condition is false.
     * @param ctx     Kernel context (not used by this specialization).
     * @throws std::runtime_error if `cond`, `thenVal` and `elseVal` do not all
     *         have the same shape, or if a pre-allocated `res` has a different
     *         shape than `cond`.
     */
    static void apply(DenseMatrix<VTVal> *&res, const DenseMatrix<VTCond> *cond, const DenseMatrix<VTVal> *thenVal,
                      const DenseMatrix<VTVal> *elseVal, DCTX(ctx)) {
        const size_t numRows = cond->getNumRows();
        const size_t numCols = cond->getNumCols();

        if (numRows != thenVal->getNumRows() || numRows != elseVal->getNumRows() || numCols != thenVal->getNumCols() ||
            numCols != elseVal->getNumCols()) {
            // Same message format as the generic Matrix specialization, so
            // that the offending shapes are visible to the user.
            std::ostringstream errMsg;
            errMsg << "CondMatMatMat: condition/then/else matrices must have "
                      "the same shape but have ("
                   << numRows << "," << numCols << "), (" << thenVal->getNumRows() << "," << thenVal->getNumCols()
                   << ") and (" << elseVal->getNumRows() << "," << elseVal->getNumCols() << ")";
            throw std::runtime_error(errMsg.str());
        }

        if (res == nullptr) {
            // NOTE(review): the third argument is presumably the "zero-initialize"
            // flag; every cell is overwritten below either way.
            res = DataObjectFactory::create<DenseMatrix<VTVal>>(numRows, numCols, false);
        } else if (res->getNumRows() != numRows || res->getNumCols() != numCols) {
            // A mismatching pre-allocated result would be written out of
            // bounds (or only partially) by the raw-pointer loop below.
            std::ostringstream errMsg;
            errMsg << "CondMatMatMat: pre-allocated result matrix must have shape (" << numRows << "," << numCols
                   << ") but has (" << res->getNumRows() << "," << res->getNumCols() << ")";
            throw std::runtime_error(errMsg.str());
        }

        VTVal *valuesRes = res->getValues();
        const VTCond *valuesCond = cond->getValues();
        const VTVal *valuesThen = thenVal->getValues();
        const VTVal *valuesElse = elseVal->getValues();

        const size_t rowSkipRes = res->getRowSkip();
        const size_t rowSkipCond = cond->getRowSkip();
        const size_t rowSkipThen = thenVal->getRowSkip();
        const size_t rowSkipElse = elseVal->getRowSkip();

        for (size_t r = 0; r < numRows; r++) {
            for (size_t c = 0; c < numCols; c++)
                valuesRes[c] = static_cast<bool>(valuesCond[c]) ? valuesThen[c] : valuesElse[c];
            // Each matrix is advanced by its own row skip, which may differ
            // between the four operands.
            valuesRes += rowSkipRes;
            valuesCond += rowSkipCond;
            valuesThen += rowSkipThen;
            valuesElse += rowSkipElse;
        }
    }
};
// ----------------------------------------------------------------------------
// Matrix <- Matrix, Matrix, Matrix
// ----------------------------------------------------------------------------
/**
 * @brief Element-wise selection on generic matrices.
 *
 * For every cell (r, c) the result is `thenVal->get(r, c)` if
 * `cond->get(r, c)` converts to `true`, and `elseVal->get(r, c)` otherwise.
 * Cells are read through `get` and written through the
 * `prepareAppend`/`append`/`finishAppend` interface of `Matrix`.
 *
 * @tparam VTVal  Value type of the then/else/result matrices.
 * @tparam VTCond Value type of the condition matrix; each value is converted
 *                with `static_cast<bool>`.
 */
template <typename VTVal, typename VTCond>
struct CondMatMatMat<Matrix<VTVal>, Matrix<VTCond>, Matrix<VTVal>, Matrix<VTVal>> {
    /**
     * @param res     Result matrix. If `nullptr`, a new `DenseMatrix<VTVal>` with
     *                the shape of `cond` is created; otherwise it must already
     *                have that shape.
     * @param cond    Condition matrix.
     * @param thenVal Values taken where the condition is true.
     * @param elseVal Values taken where the condition is false.
     * @param ctx     Kernel context (not used by this specialization).
     * @throws std::runtime_error if `cond`, `thenVal` and `elseVal` do not all
     *         have the same shape, or if a pre-allocated `res` has a different
     *         shape than `cond`.
     */
    static void apply(Matrix<VTVal> *&res, const Matrix<VTCond> *cond, const Matrix<VTVal> *thenVal,
                      const Matrix<VTVal> *elseVal, DCTX(ctx)) {
        const size_t numRows = cond->getNumRows();
        const size_t numCols = cond->getNumCols();

        if (numRows != thenVal->getNumRows() || numRows != elseVal->getNumRows() || numCols != thenVal->getNumCols() ||
            numCols != elseVal->getNumCols()) {
            std::ostringstream errMsg;
            errMsg << "CondMatMatMat: condition/then/else matrices must have "
                      "the same shape but have ("
                   << numRows << "," << numCols << "), (" << thenVal->getNumRows() << "," << thenVal->getNumCols()
                   << ") and (" << elseVal->getNumRows() << "," << elseVal->getNumCols() << ")";
            throw std::runtime_error(errMsg.str());
        }

        if (res == nullptr) {
            // NOTE(review): the third argument is presumably the "zero-initialize"
            // flag; every cell is appended below either way.
            res = DataObjectFactory::create<DenseMatrix<VTVal>>(numRows, numCols, false);
        } else if (res->getNumRows() != numRows || res->getNumCols() != numCols) {
            // Appending at (r, c) positions outside a mismatching
            // pre-allocated result must not be attempted.
            std::ostringstream errMsg;
            errMsg << "CondMatMatMat: pre-allocated result matrix must have shape (" << numRows << "," << numCols
                   << ") but has (" << res->getNumRows() << "," << res->getNumCols() << ")";
            throw std::runtime_error(errMsg.str());
        }

        res->prepareAppend();
        for (size_t r = 0; r < numRows; ++r)
            for (size_t c = 0; c < numCols; ++c)
                res->append(r, c, static_cast<bool>(cond->get(r, c)) ? thenVal->get(r, c) : elseVal->get(r, c));
        res->finishAppend();
    }
};
#endif // SRC_RUNTIME_LOCAL_KERNELS_CONDMATMATMAT_H