diff --git a/src/tensorwrapper/backends/cutensor/cuda_tensor.cpp b/src/tensorwrapper/backends/cutensor/cuda_tensor.cpp
index ee3dca0f..ea7e9e8b 100644
--- a/src/tensorwrapper/backends/cutensor/cuda_tensor.cpp
+++ b/src/tensorwrapper/backends/cutensor/cuda_tensor.cpp
@@ -15,9 +15,10 @@
  */
 
 #include "cuda_tensor.hpp"
+#include <cassert>
 
 #ifdef ENABLE_CUTENSOR
-#include "eigen_tensor.cuh"
+#include "cuda_tensor.cuh"
 #endif
 
 namespace tensorwrapper::backends::cutensor {
@@ -40,6 +41,30 @@ void CUDA_TENSOR::contraction_assignment(label_type this_label,
 #endif
 }
 
+TPARAMS
+CUDA_TENSOR::const_reference CUDA_TENSOR::get_elem(index_vector index) const {
+    assert(index.size() == rank());
+    std::size_t flat_index = 0;
+    std::size_t stride     = 1;
+    for(std::size_t i = rank(); i-- > 0;) {
+        flat_index += index.at(i) * stride;
+        stride *= m_shape_.extent(i);
+    }
+    return m_data_[flat_index];
+}
+
+TPARAMS
+void CUDA_TENSOR::set_elem(index_vector index, value_type new_value) {
+    assert(index.size() == rank());
+    std::size_t flat_index = 0;
+    std::size_t stride     = 1;
+    for(std::size_t i = rank(); i-- > 0;) {
+        flat_index += index.at(i) * stride;
+        stride *= m_shape_.extent(i);
+    }
+    m_data_[flat_index] = new_value;
+}
+
 #undef CUDA_TENSOR
 #undef TPARAMS
 
diff --git a/src/tensorwrapper/backends/cutensor/cuda_tensor.cu b/src/tensorwrapper/backends/cutensor/cuda_tensor.cu
index b3e12329..5a39ddbd 100644
--- a/src/tensorwrapper/backends/cutensor/cuda_tensor.cu
+++ b/src/tensorwrapper/backends/cutensor/cuda_tensor.cu
@@ -16,6 +16,8 @@
 #ifdef ENABLE_CUTENSOR
 #include "cuda_tensor.cuh"
 #include "cutensor_traits.cuh"
+#include <cstdio>
+#include <cstdlib>
 #include <cstdint>
 #include <vector>
 
@@ -42,7 +44,8 @@ using int64_vector_t = std::vector<int64_t>;
     if(err != CUTENSOR_STATUS_SUCCESS) {                    \
         printf("Error: %s\n", cutensorGetErrorString(err)); \
         exit(-1);                                           \
-    }
+    } \
+    };
 
 // Convert a label into a vector of modes
 template<typename LabelType>
diff --git a/src/tensorwrapper/backends/cutensor/cuda_tensor.hpp b/src/tensorwrapper/backends/cutensor/cuda_tensor.hpp
index d0ef8eb9..f71fe43e 100644
--- a/src/tensorwrapper/backends/cutensor/cuda_tensor.hpp
+++ b/src/tensorwrapper/backends/cutensor/cuda_tensor.hpp
@@ -43,6 +43,8 @@ class CUDATensor {
     using const_shape_view = shape::SmoothView<const shape::Smooth>;
     using label_type       = dsl::DummyIndices<std::string>;
     using size_type        = std::size_t;
+    using index_vector     = std::vector<size_type>;
+    using const_reference  = const value_type&;
 
     CUDATensor(span_type data, const_shape_view shape) :
       m_data_(data), m_shape_(shape) {}
@@ -57,6 +59,10 @@ class CUDATensor {
 
     auto shape() const noexcept { return m_shape_; }
 
+    const_reference get_elem(index_vector index) const;
+
+    void set_elem(index_vector index, value_type new_value);
+
     auto data() noexcept { return m_data_.data(); }
 
     auto data() const noexcept { return m_data_.data(); }
diff --git a/tests/cxx/unit_tests/tensorwrapper/backends/cutensor/cuda_tensor.cpp b/tests/cxx/unit_tests/tensorwrapper/backends/cutensor/cuda_tensor.cpp
index a71352ca..fa573cfe 100644
--- a/tests/cxx/unit_tests/tensorwrapper/backends/cutensor/cuda_tensor.cpp
+++ b/tests/cxx/unit_tests/tensorwrapper/backends/cutensor/cuda_tensor.cpp
@@ -70,6 +70,28 @@ TEMPLATE_LIST_TEST_CASE("CUDATensor", "", supported_fp_types) {
         REQUIRE(tensor4.shape() == tensor4_shape);
     }
 
+    SECTION("get_elem()") {
+        REQUIRE(scalar.get_elem({}) == data[0]);
+        REQUIRE(vector.get_elem({5}) == data[5]);
+        REQUIRE(matrix.get_elem({1, 2}) == data[6]);
+        REQUIRE(tensor3.get_elem({1, 0, 3}) == data[11]);
+        REQUIRE(tensor4.get_elem({1, 1, 1, 0}) == data[14]);
+    }
+
+    SECTION("set_elem()") {
+        scalar.set_elem({}, static_cast<TestType>(42.0));
+        vector.set_elem({5}, static_cast<TestType>(42.0));
+        matrix.set_elem({1, 2}, static_cast<TestType>(42.0));
+        tensor3.set_elem({1, 0, 3}, static_cast<TestType>(42.0));
+        tensor4.set_elem({1, 1, 1, 0}, static_cast<TestType>(42.0));
+
+        REQUIRE(scalar.get_elem({}) == static_cast<TestType>(42.0));
+        REQUIRE(vector.get_elem({5}) == static_cast<TestType>(42.0));
+        REQUIRE(matrix.get_elem({1, 2}) == static_cast<TestType>(42.0));
+        REQUIRE(tensor3.get_elem({1, 0, 3}) == static_cast<TestType>(42.0));
+        REQUIRE(tensor4.get_elem({1, 1, 1, 0}) == static_cast<TestType>(42.0));
+    }
+
     SECTION("data()") {
         REQUIRE(scalar.data() == data.data());
         REQUIRE(vector.data() == data.data());
@@ -87,8 +109,8 @@ TEMPLATE_LIST_TEST_CASE("CUDATensor", "", supported_fp_types) {
     }
 
     SECTION("contraction_assignment") {
-#ifdef ENABLE_CUTESNSOR
-        testing::contraction_assignment<TestType>();
+#ifdef ENABLE_CUTENSOR
+        testing::contraction_assignment_tests<TestType>();
 #else
         label_type label("");
         REQUIRE_THROWS_AS(
diff --git a/tests/cxx/unit_tests/tensorwrapper/backends/testing/contraction_assignment.hpp b/tests/cxx/unit_tests/tensorwrapper/backends/testing/contraction_assignment.hpp
index b6837d0e..320275b2 100644
--- a/tests/cxx/unit_tests/tensorwrapper/backends/testing/contraction_assignment.hpp
+++ b/tests/cxx/unit_tests/tensorwrapper/backends/testing/contraction_assignment.hpp
@@ -67,7 +67,7 @@ void contraction_assignment_tests() {
     VectorType vector(vector_data_span, vector_shape);
     MatrixType matrix(matrix_data_span, matrix_shape);
     Tensor3Type tensor3(tensor3_data_span, tensor3_shape);
-    Tensor4Type tensor4(tensor4_data, shape_type{2, 2, 2, 2});
+    Tensor4Type tensor4(tensor4_data_span, tensor4_shape);
 
     SECTION("scalar,scalar->") {
         label_type o("");