Skip to content

Commit

Permalink
Keep one version of gemm
Browse files Browse the repository at this point in the history
There's not enough difference/evidence for the bloat of many variants.

* bench/bench-gemm.cc: Benchmark all the variants previously used as ra::gemm.
* ra/ra.hh (gemm): Use wrank/wrank variant for all types/sizes.
  • Loading branch information
lloda committed Dec 5, 2023
1 parent d84086c commit e9dd24a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 41 deletions.
80 changes: 74 additions & 6 deletions bench/bench-gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,44 @@
#include "ra/bench.hh"

using std::cout, std::endl, std::setw, std::setprecision, ra::TestRecorder;
using ra::Small, ra::ViewBig, ra::Unique, ra::dim_t;
using ra::Small, ra::ViewBig, ra::Unique, ra::dim_t, ra::all;
using real = double;

// FIXME variants of IJ that were at some point used in gemm()
// Benchmark variant 1: a single for_each over two nested rank verbs.
// The outer verb is declared with cell ranks <1, 2, 1> for (a, b, c) and the
// inner one with <0, 1, 1>; the innermost operation is ra::maybe_fma(x, y, acc),
// i.e. acc accumulates x*y. Accumulates into c, so c must be initialised by
// the caller.
// NOTE(review): presumably this pairs each row of a with all of b and the
// matching row of c, then each scalar of that row of a with a row of b --
// confirm against the ra:: rank-verb documentation.
void
gemm1(auto && a, auto && b, auto && c)
{
    for_each(ra::wrank<1, 2, 1>(
                 ra::wrank<0, 1, 1>(
                     [](auto && x, auto && y, auto & acc) { ra::maybe_fma(x, y, acc); })),
             RA_FWD(a), RA_FWD(b), RA_FWD(c));
}
void
gemm2(auto && a, auto && b, auto && c)
{
dim_t K=a.len(1);
for (int k=0; k<K; ++k) {
c += from(std::multiplies<>(), a(all, k), b(k));
}
}
void
gemm3(auto && a, auto && b, auto && c)
{
dim_t K=a.len(1);
for (int k=0; k<K; ++k) {
for_each(ra::wrank<0, 1, 1>([](auto && a, auto && b, auto && c) { ra::maybe_fma(a, b, c); }), a(all, k), b(k), c);
}
}
// variant of K, same. Overwrites c
void
gemm4(auto && a, auto && b, auto && c)
{
dim_t M=a.len(0), N=b.len(1);
for (int i=0; i<M; ++i) {
for (int j=0; j<N; ++j) {
c(i, j) = dot(a(i), b(all, j));
}
}
}

// -------------------
// variants of the defaults, should be slower if the default is well picked.
// -------------------
Expand All @@ -38,12 +73,12 @@ gemm_block_3(ra::ViewBig<A, 2> const & a, ra::ViewBig<B, 2> const & b, ra::ViewB
gemm_block_3(a(ra::iota(m-m/2, m/2)), b, c(ra::iota(m-m/2, m/2)));
// split b's columns
} else if (n>=max(m, p)) {
gemm_block_3(a, b(ra::all, ra::iota(n/2)), c(ra::all, ra::iota(n/2)));
gemm_block_3(a, b(ra::all, ra::iota(n-n/2, n/2)), c(ra::all, ra::iota(n-n/2, n/2)));
gemm_block_3(a, b(all, ra::iota(n/2)), c(all, ra::iota(n/2)));
gemm_block_3(a, b(all, ra::iota(n-n/2, n/2)), c(all, ra::iota(n-n/2, n/2)));
// split a's columns and b's rows
} else {
gemm_block_3(a(ra::all, ra::iota(p/2)), b(ra::iota(p/2)), c);
gemm_block_3(a(ra::all, ra::iota(p-p/2, p/2)), b(ra::iota(p-p/2, p/2)), c);
gemm_block_3(a(all, ra::iota(p/2)), b(ra::iota(p/2)), c);
gemm_block_3(a(all, ra::iota(p-p/2, p/2)), b(ra::iota(p-p/2, p/2)), c);
}
}

Expand Down Expand Up @@ -127,6 +162,34 @@ int main()
cout << "FP_FAST_FMA is " << FP_FAST_FMA << endl;
cout << "RA_DO_FMA is " << RA_DO_FMA << endl;

auto gemm_ply1 = [&](auto const & a, auto const & b)
{
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.len(0), b.len(1)}, 0);
gemm1(a, b, c);
return c;
};

auto gemm_ply2 = [&](auto const & a, auto const & b)
{
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.len(0), b.len(1)}, 0);
gemm2(a, b, c);
return c;
};

auto gemm_ply3 = [&](auto const & a, auto const & b)
{
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.len(0), b.len(1)}, 0);
gemm3(a, b, c);
return c;
};

auto gemm_ply4 = [&](auto const & a, auto const & b)
{
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.len(0), b.len(1)}, ra::none);
gemm4(a, b, c);
return c;
};

auto gemm_block = [&](auto const & a, auto const & b)
{
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({a.len(0), b.len(1)}, 0);
Expand All @@ -141,7 +204,7 @@ int main()
ra::Big<decltype(a(0, 0)*b(0, 0)), 2> c({M, N}, ra::none);
for (dim_t i=0; i<M; ++i) {
for (dim_t j=0; j<N; ++j) {
c(i, j) = dot(a(i), b(ra::all, j));
c(i, j) = dot(a(i), b(all, j));
}
}
return c;
Expand Down Expand Up @@ -234,9 +297,14 @@ int main()
#if RA_USE_BLAS==1
bench(gemm_blas, "blas");
#endif
bench(gemm_ply1, "ply1");
bench(gemm_ply2, "ply2");
bench(gemm_ply3, "ply3");
bench(gemm_ply4, "ply4");
bench([&](auto const & a, auto const & b) { return gemm(a, b); }, "default");
};

bench_all(3, 4, 4, 4, 100);
bench_all(3, 10, 10, 10, 100);
bench_all(2, 100, 100, 100, 10);
bench_all(2, 500, 400, 500, 1);
Expand Down
53 changes: 18 additions & 35 deletions ra/ra.hh
Original file line number Diff line number Diff line change
Expand Up @@ -483,31 +483,27 @@ prod(auto && a)
constexpr auto reduce_sqrm(auto && a) { return sum(sqrm(a)); }
constexpr auto norm2(auto && a) { return std::sqrt(reduce_sqrm(a)); }

// Accumulate a*b into c. When RA_DO_FMA is 1 this goes through fma(a, b, c);
// otherwise it is the plain c += a*b. The choice is made at compile time.
constexpr void
maybe_fma(auto && a, auto && b, auto & c)
{
    if constexpr (1==RA_DO_FMA) {
        c = fma(a, b, c);
    } else {
        c += a*b;
    }
}

// Accumulate conj(a)*b into c. When RA_DO_FMA is 1 this goes through
// fma_conj(a, b, c); otherwise it is the plain c += conj(a)*b. The choice is
// made at compile time.
constexpr void
maybe_fma_conj(auto && a, auto && b, auto & c)
{
    if constexpr (1==RA_DO_FMA) {
        c = fma_conj(a, b, c);
    } else {
        c += conj(a)*b;
    }
}
// Accumulation primitives used by dot/cdot/gemm: c gets a*b (or conj(a)*b)
// added to it. The implementation is selected by the preprocessor: with
// RA_DO_FMA equal to 1 the fused helpers fma / fma_conj are used, otherwise
// the plain multiply-add expressions.
// (No ';' after the function bodies: that would be an empty declaration at
// namespace scope and triggers extra-semicolon warnings.)
#if 1==RA_DO_FMA
constexpr void maybe_fma(auto && a, auto && b, auto & c) { c = fma(a, b, c); }
constexpr void maybe_fma_conj(auto && a, auto && b, auto & c) { c = fma_conj(a, b, c); }
#else
constexpr void maybe_fma(auto && a, auto && b, auto & c) { c += a*b; }
constexpr void maybe_fma_conj(auto && a, auto && b, auto & c) { c += conj(a)*b; }
#endif

constexpr auto
dot(auto && a, auto && b)
{
std::decay_t<decltype(VALUE(a) * VALUE(b))> c(0.);
for_each([&c](auto && a, auto && b) { maybe_fma(a, b, c); }, a, b);
for_each([&c](auto && a, auto && b) { maybe_fma(a, b, c); }, RA_FWD(a), RA_FWD(b));
return c;
}

constexpr auto
cdot(auto && a, auto && b)
{
std::decay_t<decltype(conj(VALUE(a)) * VALUE(b))> c(0.);
for_each([&c](auto && a, auto && b) { maybe_fma_conj(a, b, c); }, a, b);
for_each([&c](auto && a, auto && b) { maybe_fma_conj(a, b, c); }, RA_FWD(a), RA_FWD(b));
return c;
}

Expand All @@ -519,37 +515,24 @@ normv(auto const & a)
return b;
}

// FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
// FIXME benchmark w/o allocation and do Small/Big versions if it's worth it (see bench-gemm.cc)

constexpr void
gemm(auto const & a, auto const & b, auto & c)
{
for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { maybe_fma(a, b, c); })),
c, a, b);
for_each(ra::wrank<1, 2, 1>(ra::wrank<0, 1, 1>([](auto && a, auto && b, auto & c) { maybe_fma(a, b, c); })),
RA_FWD(a), RA_FWD(b), RA_FWD(c));
}

// default for row-major x row-major.
// FIXME bench these exact variants in bench-gemm.cc, plus sizes.
constexpr auto
gemm(auto const & a, auto const & b)
{
dim_t M=a.len(0), N=b.len(1), K=a.len(1);
dim_t M=a.len(0), N=b.len(1);
using T = decltype(VALUE(a)*VALUE(b));
using MMTYPE = decltype(from(std::multiplies<>(), a(all, 0), b(0)));
if (K<M+N) {
auto c = with_shape<MMTYPE>({M, N}, T());
for (int k=0; k<K; ++k) {
c += from(std::multiplies<>(), a(all, k), b(k)); // FIXME fma?
}
return c;
} else {
auto c = with_shape<MMTYPE>({M, N}, ra::none);
for (int i=0; i<M; ++i) {
for (int j=0; j<N; ++j) {
c(i, j) = dot(a(i), b(all, j));
}
}
return c;
}
auto c = with_shape<MMTYPE>({M, N}, T());
gemm(a, b, c);
return c;
}

constexpr auto
Expand Down Expand Up @@ -579,7 +562,7 @@ gemv(auto const & a, auto const & b)


// --------------------
// wedge product and cross product
// wedge product and cross product. FIXME seriously
// --------------------

namespace mp {
Expand Down

0 comments on commit e9dd24a

Please sign in to comment.