Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/tensorwrapper/backends/cutensor/cuda_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
*/

#include "cuda_tensor.hpp"
#include <cassert>

#ifdef ENABLE_CUTENSOR
#include "eigen_tensor.cuh"
#include "cuda_tensor.cuh"
#endif

namespace tensorwrapper::backends::cutensor {
Expand All @@ -40,6 +41,30 @@ void CUDA_TENSOR::contraction_assignment(label_type this_label,
#endif
}

TPARAMS
CUDA_TENSOR::const_reference CUDA_TENSOR::get_elem(index_vector index) const {
assert(index.size() == rank());
std::size_t flat_index = 0;
std::size_t stride = 1;
for(std::size_t i = rank(); i-- > 0;) {
flat_index += index.at(i) * stride;
stride *= m_shape_.extent(i);
}
return m_data_[flat_index];
}

TPARAMS
void CUDA_TENSOR::set_elem(index_vector index, value_type new_value) {
assert(index.size() == rank());
std::size_t flat_index = 0;
std::size_t stride = 1;
for(std::size_t i = rank(); i-- > 0;) {
flat_index += index.at(i) * stride;
stride *= m_shape_.extent(i);
}
m_data_[flat_index] = new_value;
}

#undef CUDA_TENSOR
#undef TPARAMS

Expand Down
5 changes: 4 additions & 1 deletion src/tensorwrapper/backends/cutensor/cuda_tensor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#ifdef ENABLE_CUTENSOR
#include "cuda_tensor.cuh"
#include "cutensor_traits.cuh"
#include <cassert>
#include <tensorwrapper/types/floating_point.hpp>
#include <unordered_map>
#include <vector>

Expand All @@ -42,7 +44,8 @@ using int64_vector_t = std::vector<int64_t>;
if(err != CUTENSOR_STATUS_SUCCESS) { \
printf("Error: %s\n", cutensorGetErrorString(err)); \
exit(-1); \
}
} \
};

// Convert a label into a vector of modes
template<typename LabelType>
Expand Down
6 changes: 6 additions & 0 deletions src/tensorwrapper/backends/cutensor/cuda_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class CUDATensor {
using const_shape_view = shape::SmoothView<const shape_type>;
using label_type = dsl::DummyIndices<std::string>;
using size_type = std::size_t;
using index_vector = std::vector<size_type>;
using const_reference = const value_type&;

CUDATensor(span_type data, const_shape_view shape) :
m_data_(data), m_shape_(shape) {}
Expand All @@ -57,6 +59,10 @@ class CUDATensor {

auto shape() const noexcept { return m_shape_; }

const_reference get_elem(index_vector index) const;

void set_elem(index_vector index, value_type new_value);

auto data() noexcept { return m_data_.data(); }

auto data() const noexcept { return m_data_.data(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,28 @@ TEMPLATE_LIST_TEST_CASE("CUDATensor", "", supported_fp_types) {
REQUIRE(tensor4.shape() == tensor4_shape);
}

SECTION("get_elem()") {
REQUIRE(scalar.get_elem({}) == data[0]);
REQUIRE(vector.get_elem({5}) == data[5]);
REQUIRE(matrix.get_elem({1, 2}) == data[6]);
REQUIRE(tensor3.get_elem({1, 0, 3}) == data[11]);
REQUIRE(tensor4.get_elem({1, 1, 1, 0}) == data[14]);
}

SECTION("set_elem()") {
scalar.set_elem({}, static_cast<TestType>(42.0));
vector.set_elem({5}, static_cast<TestType>(42.0));
matrix.set_elem({1, 2}, static_cast<TestType>(42.0));
tensor3.set_elem({1, 0, 3}, static_cast<TestType>(42.0));
tensor4.set_elem({1, 1, 1, 0}, static_cast<TestType>(42.0));

REQUIRE(scalar.get_elem({}) == static_cast<TestType>(42.0));
REQUIRE(vector.get_elem({5}) == static_cast<TestType>(42.0));
REQUIRE(matrix.get_elem({1, 2}) == static_cast<TestType>(42.0));
REQUIRE(tensor3.get_elem({1, 0, 3}) == static_cast<TestType>(42.0));
REQUIRE(tensor4.get_elem({1, 1, 1, 0}) == static_cast<TestType>(42.0));
}

SECTION("data()") {
REQUIRE(scalar.data() == data.data());
REQUIRE(vector.data() == data.data());
Expand All @@ -87,8 +109,8 @@ TEMPLATE_LIST_TEST_CASE("CUDATensor", "", supported_fp_types) {
}

SECTION("contraction_assignment") {
#ifdef ENABLE_CUTESNSOR
testing::contraction_assignment<tensor_type>();
#ifdef ENABLE_CUTENSOR
testing::contraction_assignment_tests<tensor_type>();
#else
label_type label("");
REQUIRE_THROWS_AS(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void contraction_assignment_tests() {
VectorType vector(vector_data_span, vector_shape);
MatrixType matrix(matrix_data_span, matrix_shape);
Tensor3Type tensor3(tensor3_data_span, tensor3_shape);
Tensor4Type tensor4(tensor4_data, shape_type{2, 2, 2, 2});
Tensor4Type tensor4(tensor4_data_span, tensor4_shape);

SECTION("scalar,scalar->") {
label_type o("");
Expand Down
Loading