Skip to content

Commit

Permalink
WIP SYCL parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
Iluvmagick committed Sep 30, 2024
1 parent ff2a92c commit 3a667eb
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 11 deletions.
4 changes: 3 additions & 1 deletion parallel-crypto3/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ if (${FORCE_COLORED_OUTPUT})
endif ()
endif ()

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# The file compile_commands.json is generated in build directory, so LSP could
# pick it up and guess all include paths, defines and other stuff.
# If Nix is used, LSP could not guess the locations of implicit include
# directories, so we need to include them explicitly.
if(CMAKE_EXPORT_COMPILE_COMMANDS)
set(CMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES
set(CMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES
${CMAKE_CXX_IMPLICIT_INCLUDE_DIRECTORIES})
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
// SOFTWARE.
//---------------------------------------------------------------------------//

#ifndef CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP
#define CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP
#pragma once

#include <algorithm>
#include <memory>
Expand All @@ -38,6 +37,7 @@

#include <nil/actor/core/thread_pool.hpp>
#include <nil/actor/core/parallelization_utils.hpp>
#include <nil/actor/core/sycl_parallelization_utils.hpp>

namespace nil {
namespace crypto3 {
Expand Down Expand Up @@ -83,7 +83,7 @@ namespace nil {

// swapping in place (from Storer's book)
// We can parallelize this look, since k and rk are pairs, they will never intersect.
nil::crypto3::parallel_for(0, n,
sycl_parallel_for(0, n,
[logn, &a](std::size_t k) {
const std::size_t rk = crypto3::math::detail::bitreverse(k, logn);
if (k < rk)
Expand All @@ -100,7 +100,7 @@ namespace nil {

// Here we can parallelize on the both loops with 'k' and 'm', because for each value of k and m
// the ranges of array 'a' used do not intersect. Think of these 2 loops as 1.
wait_for_all(parallel_run_in_chunks<void>(
sycl_run_in_chunks(
m * count_k,
[&a, m, count_k, inc, &omega_cache](std::size_t begin, std::size_t end) {
size_t current_index = begin;
Expand All @@ -124,8 +124,7 @@ namespace nil {
return;
}
}
}, ThreadPool::PoolLevel::LOW
));
}));
}
}

Expand Down Expand Up @@ -209,5 +208,3 @@ namespace nil {
} // namespace fft
} // namespace crypto3
} // namespace nil

#endif // ALGEBRA_FFT_BASIC_RADIX2_DOMAIN_AUX_HPP
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ BOOST_AUTO_TEST_CASE(polynomial_dfs_multiplication_perf_test, *boost::unit_test:
std::cout << "Multiplication time: " << duration.count() << " microseconds." << std::endl;
}

BOOST_AUTO_TEST_CASE(polynomial_dfs_resize_perf_test, *boost::unit_test::disabled()) {
BOOST_AUTO_TEST_CASE(polynomial_dfs_resize_perf_test) {
std::vector<typename FieldType::value_type> values;
std::size_t size = 131072 * 16;
for (std::size_t i = 0; i < size; i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//---------------------------------------------------------------------------//
// Copyright (c) 2024 Dmitrii Tabalin <[email protected]>
//
// MIT License
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//---------------------------------------------------------------------------//

#pragma once

#include <functional>
#include <hipsycl/sycl.hpp>

namespace nil {
namespace crypto3 {
template<class Function>
void sycl_run_in_chunks(
std::size_t elements_count,
Function func
) {
hipsycl::queue q;
std::size_t max_compute_units = q.get_device().get_info<hipsycl::info::device::max_compute_units>();
std::size_t workers_to_use =
std::max(static_cast<std::size_t>(1), std::min(elements_count, max_compute_units));
{
q.submit([&](hipsycl::handler& cgh) {
cgh.parallel_for<class ParallelRunKernel>(
hipsycl::range<1>(workers_to_use), [=](hipsycl::id<1> idx) {
const std::size_t i = idx[0];
const std::size_t chunk_size = elements_count / workers_to_use;
const std::size_t remainder = elements_count % workers_to_use;
const std::size_t begin = i * chunk_size + hipsycl::min(i, remainder);
const std::size_t end = begin + chunk_size + (i < remainder ? 1 : 0);
func(begin, end);
});
});
// The buffer destructor ensures synchronization
}
}

template<class Function>
void sycl_parallel_for(
std::size_t start,
std::size_t end,
Function func
) {
hipsycl::queue q;
{
q.submit([&](hipsycl::handler& cgh) {
cgh.parallel_for<class ParallelForKernel>(
hipsycl::range<1>(end - start), [=](hipsycl::id<1> idx) {
func(start + idx[0]);
});
});
// The buffer destructor ensures synchronization
}
}
} // namespace crypto3
} // namespace nil
6 changes: 5 additions & 1 deletion parallel-crypto3/parallel-crypto3.nix
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
cmake,
boost,
gdb,
lldb,
cmake_modules,
crypto3,
opensycl,
enableDebugging,
enableDebug ? false,
runTests ? false,
Expand All @@ -18,7 +20,9 @@ in stdenv.mkDerivation {

src = lib.sourceByRegex ./. [ ".*" ];

nativeBuildInputs = [ cmake ninja pkg-config ] ++ (lib.optional (!stdenv.isDarwin) gdb);
nativeBuildInputs = [ cmake ninja pkg-config opensycl ] ++
(lib.optional (!stdenv.isDarwin) gdb) ++
(lib.optional (stdenv.isDarwin) lldb);

# enableDebugging will keep debug symbols in boost
propagatedBuildInputs = [ (if enableDebug then (enableDebugging boost) else boost) ];
Expand Down

0 comments on commit 3a667eb

Please sign in to comment.