-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPSTLStream.h
182 lines (146 loc) · 5.99 KB
/
PSTLStream.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#pragma once
#ifdef DPCPP_BACKEND
#include <oneapi/dpl/algorithm>
#include <oneapi/dpl/execution>//<= why does it need tbb?
#include <oneapi/dpl/iterator>
#include <oneapi/dpl/random>
#else
#include <algorithm>
#include <execution>
#include <numeric>
#include <limits>
#endif
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include <memory>
#if defined (DPCPP_BACKEND)
#include <CL/sycl.hpp>
using oneapi::dpl::counting_iterator;
#endif
namespace impl {

// Random-access iterator over consecutive integers with no backing storage:
// dereferencing yields the current integer itself. It lets index-based
// kernels be expressed as algorithm calls over the range [first, last).
//
// IntType must be an integer type (checked at compile time).
template <typename IntType>
class counting_iterator {
  static_assert(std::numeric_limits<IntType>::is_integer, "Cannot instantiate counting_iterator with a non-integer type");

public:
  using value_type = IntType;
  using difference_type = typename std::make_signed<IntType>::type;
  using pointer = IntType*;
  // operator* returns the value by copy (there is no object to refer to), so
  // `reference` has to be the value type too. Declaring it as IntType& breaks
  // generic code that writes `iterator_traits<It>::reference r = *it;`.
  using reference = IntType;
  using iterator_category = std::random_access_iterator_tag;

  // Starts counting at zero.
  counting_iterator() : value(0) { }
  // Starts counting at v.
  explicit counting_iterator(IntType v) : value(v) { }

  // Current value.
  value_type operator*() const { return value; }
  // Value n steps away from the current one; the iterator is not modified.
  value_type operator[](difference_type n) const { return value + n; }

  counting_iterator& operator++() { ++value; return *this; }
  counting_iterator operator++(int) {
    counting_iterator result{value};
    ++value;
    return result;
  }
  counting_iterator& operator--() { --value; return *this; }
  counting_iterator operator--(int) {
    counting_iterator result{value};
    --value;
    return result;
  }

  counting_iterator& operator+=(difference_type n) { value += n; return *this; }
  counting_iterator& operator-=(difference_type n) { value -= n; return *this; }

  friend counting_iterator operator+(counting_iterator const& i, difference_type n) { return counting_iterator(i.value + n); }
  friend counting_iterator operator+(difference_type n, counting_iterator const& i) { return counting_iterator(i.value + n); }
  // Distance between two iterators, i.e. the difference of their values.
  friend difference_type operator-(counting_iterator const& x, counting_iterator const& y) { return x.value - y.value; }
  friend counting_iterator operator-(counting_iterator const& i, difference_type n) { return counting_iterator(i.value - n); }

  // Iterators compare exactly like the integers they hold.
  friend bool operator==(counting_iterator const& x, counting_iterator const& y) { return x.value == y.value; }
  friend bool operator!=(counting_iterator const& x, counting_iterator const& y) { return x.value != y.value; }
  friend bool operator<(counting_iterator const& x, counting_iterator const& y) { return x.value < y.value; }
  friend bool operator<=(counting_iterator const& x, counting_iterator const& y) { return x.value <= y.value; }
  friend bool operator>(counting_iterator const& x, counting_iterator const& y) { return x.value > y.value; }
  friend bool operator>=(counting_iterator const& x, counting_iterator const& y) { return x.value >= y.value; }

private:
  IntType value;
};

} //impl
// Name of this implementation as a string literal.
// NOTE(review): presumably reported by the benchmark driver - confirm against the caller.
#define IMPLEMENTATION_STRING "PSTL"
// pstl_impl names the namespace that supplies the parallel algorithms:
// oneapi::dpl when DPCPP_BACKEND is defined, std otherwise.
#ifdef DPCPP_BACKEND
namespace pstl_impl = oneapi::dpl;
#else
namespace pstl_impl = std;
// Makes impl::counting_iterator reachable by its unqualified name. The
// DPCPP_BACKEND build gets oneapi::dpl::counting_iterator from the
// using-declaration near the top of the file instead.
using namespace impl;
#endif
// Abstract interface for one stream-benchmark backend operating on values of
// type T. Concrete backends derive from it and implement every kernel.
template <typename T>
class Stream
{
public:
  virtual ~Stream() = default;

  // Kernels.
  // Implementations must not return before the work has finished.
  virtual void copy() = 0;
  virtual void mul(const T &s) = 0;
  virtual void add() = 0;
  virtual void triad(const T &s) = 0;
  virtual T dot() = 0;

  // Copy memory between host and device: hands the backend the three initial
  // values for its arrays.
  virtual void init_arrays(T &&initA, T &&initB, T &&initC) = 0;
};
// Stream backend whose kernels are expressed as parallel-algorithm calls.
//   T         - element type of the three arrays
//   Policy    - type of the execution policy passed to every algorithm call
//   Allocator - allocator used by the three std::vector arrays
template <typename T, typename Policy, class Allocator>
class PSTLStream : public Stream<T>
{
protected:
  // Execution policy supplied by the caller. It is held by reference, so the
  // caller's policy object has to stay alive as long as this stream does.
  Policy &p;

  // The three benchmark arrays (labelled "device side" by the original
  // author; where the memory actually lives is decided by Allocator).
  std::vector<T, Allocator> a;
  std::vector<T, Allocator> b;
  std::vector<T, Allocator> c;

public:
  // Binds the policy and creates three arrays of N elements each, all
  // obtaining their storage through a copy of alloc.
  PSTLStream(Policy &p_, const int N, const Allocator alloc)
    : p(p_),
      a(N, alloc),
      b(N, alloc),
      c(N, alloc)
  {
  }

  ~PSTLStream() override = default;

  void copy() override;
  void add() override;
  void mul(const T &s) override;
  void triad(const T &s) override;
  T dot() override;
  void init_arrays(T &&initA, T &&initB, T &&initC) override;
};
// Sets every element of a to initA, of b to initB and of c to initC, in that
// order, using the stored execution policy.
template <typename T, typename Policy, class Allocator>
void PSTLStream<T, Policy, Allocator>::init_arrays(T &&initA, T &&initB, T &&initC)
{
  // One shared helper instead of three spelled-out calls. The named
  // rvalue-reference parameters are lvalues here and are only read.
  auto fill_with = [this](std::vector<T, Allocator> &target, const T &init) {
    std::fill(p, target.begin(), target.end(), init);
  };

  fill_with(a, initA);
  fill_with(b, initB);
  fill_with(c, initC);
}
// Copy kernel: c[i] = a[i] for the whole of a.
template <typename T, typename Policy, class Allocator>
void PSTLStream<T, Policy, Allocator>::copy()
{
  auto src_first = a.begin();
  auto src_last = a.end();
  auto dst_first = c.begin();
  std::copy(p, src_first, src_last, dst_first);
}
// Scale kernel: b[i] = s_ * c[i] for every index i in [0, b.size()).
template <typename T, typename Policy, class Allocator>
void PSTLStream<T, Policy, Allocator>::mul(const T &s_)
{
  // Local copy of the scalar; the kernel captures this value.
  const T scalar = s_;
  const int count = b.size();

  // The kernel works on raw element pointers and the scalar only, so it does
  // not capture the stream object itself.
  auto scale = [scalar, dst = b.data(), src = c.data()](const auto idx) {
    dst[idx] = scalar * src[idx];
  };

  std::for_each(p, counting_iterator(0), counting_iterator(count), scale);
}
// Add kernel: c[i] = a[i] + b[i] for every index i in [0, c.size()).
template <typename T, typename Policy, class Allocator>
void PSTLStream<T, Policy, Allocator>::add()
{
  const int count = c.size();

  // The kernel captures raw element pointers only.
  auto sum_into = [lhs = a.data(), rhs = b.data(), dst = c.data()](const auto idx) {
    dst[idx] = lhs[idx] + rhs[idx];
  };

  std::for_each(p, counting_iterator(0), counting_iterator(count), sum_into);
}
// Triad kernel: a[i] = b[i] + s_ * c[i] for every index i in [0, a.size()).
template <typename T, typename Policy, class Allocator>
void PSTLStream<T, Policy, Allocator>::triad(const T &s_)
{
  // Local copy of the scalar; the kernel captures this value.
  const T scalar = s_;
  const int count = a.size();

  // The kernel captures raw element pointers and the scalar only.
  auto fused = [scalar, dst = a.data(), addend = b.data(), scaled = c.data()](const auto idx) {
    dst[idx] = addend[idx] + scalar * scaled[idx];
  };

  std::for_each(p, counting_iterator(0), counting_iterator(count), fused);
}
// Dot kernel: returns the reduction of a[i] * b[i] over the whole of a,
// combined with pstl_impl::plus<T> starting from T(0).
template <typename T, typename Policy, class Allocator>
T PSTLStream<T, Policy, Allocator>::dot()
{
  // Element-wise product fed into the reduction; captures nothing.
  auto product = [](const auto &lhs, const auto &rhs) { return lhs * rhs; };

  return pstl_impl::transform_reduce(p,
                                     a.begin(), a.end(),
                                     b.begin(),
                                     static_cast<T>(0.0),
                                     pstl_impl::plus<T>(),
                                     product);
}