Skip to content

Commit

Permalink
Keep one version of gemm
Browse files Browse the repository at this point in the history
There's not enough difference/evidence for the bloat of many variants.

* bench/bench-gemm.cc: Benchmark all the variants previously used as ra::gemm.
* ra/ra.hh (gemm): Use wrank/wrank variant for all types/sizes.
  • Loading branch information
lloda committed Dec 5, 2023
1 parent d84086c commit 47e9fa0
Show file tree
Hide file tree
Showing 7 changed files with 1,075 additions and 1,033 deletions.
1 change: 0 additions & 1 deletion bench/bench-dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ int main()
TestRecorder tr;
tr.o.width(6);
tr.o.precision(4);
cout << "FP_FAST_FMA is " << FP_FAST_FMA << endl;
cout << "RA_DO_FMA is " << RA_DO_FMA << endl;

std::random_device rand;
Expand Down
81 changes: 74 additions & 7 deletions bench/bench-gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,44 @@
#include "ra/bench.hh"

using std::cout, std::endl, std::setw, std::setprecision, ra::TestRecorder;
using ra::Small, ra::ViewBig, ra::Unique, ra::dim_t;
using ra::Small, ra::ViewBig, ra::Unique, ra::dim_t, ra::all;
using real = double;

// FIXME variants of IJ that were at some point used in gemm()
// Variant 1 (wrank/wrank): one for_each over the full arrays, with the scalar
// kernel ra::maybe_fma wrapped in two nested ra::wrank verbs -- <0, 1, 1> on the
// inside, <1, 2, 1> on the outside. The kernel receives its third argument by
// lvalue reference.
// NOTE(review): presumably this accumulates the product a*b into c, so c should
// hold the starting value (e.g. zeros) on entry -- confirm against ra::maybe_fma.
void
gemm1(auto && a, auto && b, auto && c)
{
    for_each(ra::wrank<1, 2, 1>(
                 ra::wrank<0, 1, 1>(
                     [](auto && x, auto && y, auto & acc) { ra::maybe_fma(x, y, acc); })),
             RA_FWD(a), RA_FWD(b), RA_FWD(c));
}
void
gemm2(auto && a, auto && b, auto && c)
{
dim_t K=a.len(1);
for (int k=0; k<K; ++k) {
c += from(std::multiplies<>(), a(all, k), b(k));
}
}
void
gemm3(auto && a, auto && b, auto && c)
{
dim_t K=a.len(1);
for (int k=0; k<K; ++k) {
for_each(ra::wrank<0, 1, 1>([](auto && a, auto && b, auto && c) { ra::maybe_fma(a, b, c); }), a(all, k), b(k), c);
}
}
// variant of K, same. Overwrites c
void
gemm4(auto && a, auto && b, auto && c)
{
dim_t M=a.len(0), N=b.len(1);
for (int i=0; i<M; ++i) {
for (int j=0; j<N; ++j) {
c(i, j) = dot(a(i), b(all, j));
}
}
}

// -------------------
// variants of the defaults, should be slower if the default is well picked.
// -------------------
Expand All @@ -38,12 +73,12 @@ gemm_block_3(ra::ViewBig<A, 2> const & a, ra::ViewBig<B, 2> const & b, ra::ViewB
gemm_block_3(a(ra::iota(m-m/2, m/2)), b, c(ra::iota(m-m/2, m/2)));
// split b's columns
} else if (n>=max(m, p)) {
gemm_block_3(a, b(ra::all, ra::iota(n/2)), c(ra::all, ra::iota(n/2)));
gemm_block_3(a, b(ra::all, ra::iota(n-n/2, n/2)), c(ra::all, ra::iota(n-n/2, n/2)));
gemm_block_3(a, b(all, ra::iota(n/2)), c(all, ra::iota(n/2)));
gemm_block_3(a, b(all, ra::iota(n-n/2, n/2)), c(all, ra::iota(n-n/2, n/2)));
// split a's columns and b's rows
} else {
gemm_block_3(a(ra::all, ra::iota(p/2)), b(ra::iota(p/2)), c);
gemm_block_3(a(ra::all, ra::iota(p-p/2, p/2)), b(ra::iota(p-p/2, p/2)), c);
gemm_block_3(a(all, ra::iota(p/2)), b(ra::iota(p/2)), c);
gemm_block_3(a(all, ra::iota(p-p/2, p/2)), b(ra::iota(p-p/2, p/2)), c);
}
}

Expand Down Expand Up @@ -124,9 +159,36 @@ gemm_blas(ra::ViewBig<double, 2> const & a, ra::ViewBig<double, 2> const & b)
int main()
{
TestRecorder tr(std::cout);
cout << "FP_FAST_FMA is " << FP_FAST_FMA << endl;
cout << "RA_DO_FMA is " << RA_DO_FMA << endl;

// Benchmark adapter for gemm1: builds a rank-2 result with shape
// {a.len(0), b.len(1)} and initial value 0, lets gemm1 update it, and returns it.
auto gemm_ply1 = [&](auto const & a, auto const & b)
{
    // Element type of the result is whatever multiplying one element of a by one of b yields.
    using product_t = decltype(a(0, 0)*b(0, 0));
    ra::Big<product_t, 2> c({a.len(0), b.len(1)}, 0);
    gemm1(a, b, c);
    return c;
};

// Benchmark adapter for gemm2: builds a rank-2 result with shape
// {a.len(0), b.len(1)} and initial value 0, lets gemm2 update it, and returns it.
auto gemm_ply2 = [&](auto const & a, auto const & b)
{
    // Element type of the result is whatever multiplying one element of a by one of b yields.
    using product_t = decltype(a(0, 0)*b(0, 0));
    ra::Big<product_t, 2> c({a.len(0), b.len(1)}, 0);
    gemm2(a, b, c);
    return c;
};

// Benchmark adapter for gemm3: builds a rank-2 result with shape
// {a.len(0), b.len(1)} and initial value 0, lets gemm3 update it, and returns it.
auto gemm_ply3 = [&](auto const & a, auto const & b)
{
    // Element type of the result is whatever multiplying one element of a by one of b yields.
    using product_t = decltype(a(0, 0)*b(0, 0));
    ra::Big<product_t, 2> c({a.len(0), b.len(1)}, 0);
    gemm3(a, b, c);
    return c;
};

// Benchmark adapter for gemm4: builds a rank-2 result with shape
// {a.len(0), b.len(1)}, passing ra::none instead of a fill value, then lets
// gemm4 assign every c(i, j) and returns the result.
// NOTE(review): ra::none presumably leaves the elements uninitialized -- confirm.
auto gemm_ply4 = [&](auto const & a, auto const & b)
{
    // Element type of the result is whatever multiplying one element of a by one of b yields.
    using product_t = decltype(a(0, 0)*b(0, 0));
    ra::Big<product_t, 2> c({a.len(0), b.len(1)}, ra::none);
    gemm4(a, b, c);
    return c;
};

auto gemm_block = [&](auto const & a, auto const & b)
{
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.len(0), b.len(1)}, 0);
Expand All @@ -141,7 +203,7 @@ int main()
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({M, N}, ra::none);
for (dim_t i=0; i<M; ++i) {
for (dim_t j=0; j<N; ++j) {
c(i, j) = dot(a(i), b(ra::all, j));
c(i, j) = dot(a(i), b(all, j));
}
}
return c;
Expand Down Expand Up @@ -234,9 +296,14 @@ int main()
#if RA_USE_BLAS==1
bench(gemm_blas, "blas");
#endif
bench(gemm_ply1, "ply1");
bench(gemm_ply2, "ply2");
bench(gemm_ply3, "ply3");
bench(gemm_ply4, "ply4");
bench([&](auto const & a, auto const & b) { return gemm(a, b); }, "default");
};

bench_all(3, 4, 4, 4, 100);
bench_all(3, 10, 10, 10, 100);
bench_all(2, 100, 100, 100, 10);
bench_all(2, 500, 400, 500, 1);
Expand Down
1 change: 0 additions & 1 deletion bench/bench-gemv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ enum trans_t { NOTRANS, TRANS };
int main()
{
TestRecorder tr(std::cout);
cout << "FP_FAST_FMA is " << FP_FAST_FMA << endl;
cout << "RA_DO_FMA is " << RA_DO_FMA << endl;

auto gemv_i = [&](auto const & a, auto const & b)
Expand Down
Loading

0 comments on commit 47e9fa0

Please sign in to comment.