forked from daphne-project/daphne
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAvgPoolForward.h
More file actions
44 lines (37 loc) · 2.61 KB
/
Copy pathAvgPoolForward.h
File metadata and controls
44 lines (37 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include "Pooling.h"
// ****************************************************************************
// Struct for partial template specialization
// ****************************************************************************
/**
 * @brief Compile-time dispatch point for the forward pass of average pooling.
 *
 * The primary template deliberately has no implementation: `apply` is
 * deleted, so any (DTRes, DTArg) combination that is not covered by a
 * partial specialization further down is rejected when it is called.
 *
 * @tparam DTRes Data type of the result written through `res`.
 * @tparam DTArg Data type of the input argument `data`.
 */
template <class DTRes, class DTArg> struct AvgPoolForward {
    static void apply(DTRes *&res,
                      size_t &res_h,
                      size_t &res_w,
                      const DTArg *data,
                      const size_t batch_size,
                      const size_t num_channels,
                      const size_t img_h,
                      const size_t img_w,
                      const size_t pool_h,
                      const size_t pool_w,
                      const size_t stride_h,
                      const size_t stride_w,
                      const size_t pad_h,
                      const size_t pad_w,
                      DCTX(dctx)) = delete;
};
// ****************************************************************************
// Convenience function
// ****************************************************************************
template <class DTRes, class DTArg>
void avgPoolForward(DTRes *&res, size_t &res_h, size_t &res_w, const DTArg *data, const size_t batch_size,
const size_t num_channels, const size_t img_h, const size_t img_w, const size_t pool_h,
const size_t pool_w, const size_t stride_h, const size_t stride_w, const size_t pad_h,
const size_t pad_w, DCTX(dctx)) {
AvgPoolForward<DTRes, DTArg>::apply(res, res_h, res_w, data, batch_size, num_channels, img_h, img_w, pool_h, pool_w,
stride_h, stride_w, pad_h, pad_w, dctx);
}
// ****************************************************************************
// (Partial) template specializations for different data/value types
// ****************************************************************************
// ----------------------------------------------------------------------------
// DenseMatrix <- DenseMatrix
// ----------------------------------------------------------------------------
/**
 * @brief DenseMatrix -> DenseMatrix average pooling.
 *
 * Contains no pooling logic of its own: it delegates to the shared
 * `NN::Pooling::Forward` implementation instantiated with the `AVG`
 * operation code, forwarding all arguments in their original order.
 */
template <typename VTRes, typename VTArg> struct AvgPoolForward<DenseMatrix<VTRes>, DenseMatrix<VTArg>> {
    static void apply(DenseMatrix<VTRes> *&res, size_t &res_h, size_t &res_w, const DenseMatrix<VTArg> *data,
                      const size_t batch_size, const size_t num_channels, const size_t img_h, const size_t img_w,
                      const size_t pool_h, const size_t pool_w, const size_t stride_h, const size_t stride_w,
                      const size_t pad_h, const size_t pad_w, DCTX(dctx)) {
        // Generic pooling kernel specialised for averaging over dense matrices.
        using AvgPoolImpl = NN::Pooling::Forward<NN::Pooling::AVG, DenseMatrix<VTRes>, DenseMatrix<VTArg>>;
        AvgPoolImpl::apply(res, res_h, res_w, data,
                           batch_size, num_channels, img_h, img_w,
                           pool_h, pool_w, stride_h, stride_w,
                           pad_h, pad_w, dctx);
    }
};