-
Notifications
You must be signed in to change notification settings - Fork 107
/
Copy pathgauge_shift.cu
53 lines (43 loc) · 1.57 KB
/
gauge_shift.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <tunable_nd.h>
#include <instantiate.h>
#include <gauge_field.h>
#include <kernels/gauge_shift.cuh>
namespace quda
{
template <typename Float, int nColor, QudaReconstructType recon_u> class ShiftGauge : public TunableKernel3D
{
GaugeField &out;
const GaugeField ∈
const array<int, 4> &dx;
unsigned int minThreads() const { return in.VolumeCB(); }
public:
ShiftGauge(GaugeField &out, const GaugeField &in, const array<int, 4> &dx) :
TunableKernel3D(in, 2, in.Geometry()), out(out), in(in), dx(dx)
{
strcat(aux, ",shift=");
for (int i = 0; i < in.Ndim(); i++) { strcat(aux, std::to_string(dx[i]).c_str()); }
strcat(aux, comm_dim_partitioned_string());
apply(device::get_default_stream());
}
void apply(const qudaStream_t &stream)
{
TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
launch<GaugeShift>(tp, stream, GaugeShiftArg<Float, nColor, recon_u>(out, in, dx));
}
void preTune() { }
void postTune() { }
long long flops() const { return in.Volume() * 4; }
long long bytes() const { return in.Bytes(); }
};
// Entry point for shifting a gauge field: validates that `out` and `in` are
// compatible and then dispatches to ShiftGauge through instantiate<>.
//   out - field that receives the result
//   in  - field that is read
//   dx  - shift to apply, one entry per dimension
void gaugeShift(GaugeField &out, const GaugeField &in, const array<int, 4> &dx)
{
  // Compatibility checks between the two fields: precision, location and
  // reconstruct type, in that order.
  checkPrecision(in, out);
  checkLocation(in, out);
  checkReconstruct(in, out);

  // Both fields must also have the same geometry.
  const auto out_geometry = out.Geometry();
  const auto in_geometry = in.Geometry();
  if (out_geometry != in_geometry) errorQuda("Field geometries %d %d do not match", out_geometry, in_geometry);

  // instantiate<> peels the reconstruct type off its first argument, so a
  // gauge field has to come first in the argument list.
  instantiate<ShiftGauge>(out, in, dx);
}
} // namespace quda