Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tensorwrapper/tensorwrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tensorwrapper/symmetry/symmetry.hpp>
#include <tensorwrapper/tensor/tensor.hpp>
#include <tensorwrapper/types/types.hpp>
#include <tensorwrapper/utilities/utilities.hpp>

/** @brief Contains the components of the TensorWrapper library. */
namespace tensorwrapper {}
69 changes: 69 additions & 0 deletions include/tensorwrapper/utilities/floating_point_dispatch.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2025 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <tensorwrapper/allocator/eigen.hpp>
#include <tensorwrapper/buffer/buffer_base.hpp>

namespace tensorwrapper::utilities {

/** @brief Wraps the logic needed to work out the floating point type of buffer.
 *
 *  @tparam KernelType Type of a functor. The functor must define a function
 *                     template called `run` that takes one explicit template
 *                     type parameter (will be the floating point type of @p
 *                     buffer) and @p buffer. `run` may take an arbitrary amount
 *                     of additional arguments.
 *  @tparam BufferType The type of @p buffer. Must be derived from BufferBase.
 *                     May contain cv or reference qualifiers.
 *  @tparam Args The types of any additional arguments which will be forwarded
 *               to @p kernel.
 *
 *  @param[in] kernel The functor instance to call `run` on.
 *  @param[in] buffer The type of the elements in @p buffer will be used to
 *                    dispatch.
 *  @param[in] args Any additional arguments to forward to @p kernel.
 *
 *  @return Returns whatever @p kernel returns.
 *
 *  @throw std::runtime_error if the floating-point type of @p buffer is not
 *                            one of the types this function dispatches on
 *                            (float, double, ufloat, or udouble). Strong throw
 *                            guarantee.
 */
template<typename KernelType, typename BufferType, typename... Args>
decltype(auto) floating_point_dispatch(KernelType&& kernel, BufferType&& buffer,
                                       Args&&... args) {
    using buffer_clean       = std::decay_t<BufferType>;
    using buffer_base        = buffer::BufferBase;
    constexpr bool is_buffer = std::is_base_of_v<buffer_base, buffer_clean>;
    // Catch misuse at compile time with a readable diagnostic instead of a
    // bare static_assert failure.
    static_assert(is_buffer, "BufferType must be derived from BufferBase");

    using types::udouble;
    using types::ufloat;

    // Try each supported floating-point type in turn; the first allocator
    // that can rebind the buffer determines the dispatch type.
    if(allocator::Eigen<float>::can_rebind(buffer)) {
        return kernel.template run<float>(buffer, std::forward<Args>(args)...);
    } else if(allocator::Eigen<double>::can_rebind(buffer)) {
        return kernel.template run<double>(buffer, std::forward<Args>(args)...);
    } else if(allocator::Eigen<ufloat>::can_rebind(buffer)) {
        return kernel.template run<ufloat>(buffer, std::forward<Args>(args)...);
    } else if(allocator::Eigen<udouble>::can_rebind(buffer)) {
        return kernel.template run<udouble>(buffer,
                                            std::forward<Args>(args)...);
    } else {
        throw std::runtime_error("Can't rebind buffer to Contiguous<>");
    }
}

} // namespace tensorwrapper::utilities
21 changes: 21 additions & 0 deletions include/tensorwrapper/utilities/utilities.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2025 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <tensorwrapper/utilities/floating_point_dispatch.hpp>

/// Namespace for helper functions
namespace tensorwrapper::utilities {}
35 changes: 24 additions & 11 deletions src/tensorwrapper/operations/approximately_equal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,31 @@
#include <tensorwrapper/allocator/eigen.hpp>
#include <tensorwrapper/buffer/eigen.hpp>
#include <tensorwrapper/operations/approximately_equal.hpp>
#include <tensorwrapper/utilities/floating_point_dispatch.hpp>

namespace tensorwrapper::operations {
namespace {

/// Functor dispatched on the element type of a buffer; decides whether every
/// element of @p result has magnitude strictly below @p tol.
struct Kernel {
    /// @tparam FloatType The element type of @p result (chosen by dispatch).
    /// @param[in] result Buffer holding the element-wise differences.
    /// @param[in] tol Tolerance; converted to FloatType before comparing.
    /// @return true if every |element| < tol, false otherwise.
    template<typename FloatType>
    bool run(const buffer::BufferBase& result, double tol) {
        using allocator_type = allocator::Eigen<FloatType>;
        const FloatType zero{0.0};
        const FloatType ptol = static_cast<FloatType>(tol);
        auto& typed_buffer   = allocator_type::rebind(result);

        const auto* pdata = typed_buffer.data();
        const auto n      = typed_buffer.size();
        for(std::size_t i = 0; i < n; ++i) {
            auto element = pdata[i];
            // Manual absolute value: std::fabs may not apply to FloatType.
            if(element < zero) element *= -1.0;
            if(element >= ptol) return false;
        }
        return true;
    }
};

bool approximately_equal(const Tensor& lhs, const Tensor& rhs, double tol) {
using allocator_type = allocator::Eigen<double>;
} // namespace

bool approximately_equal(const Tensor& lhs, const Tensor& rhs, double tol) {
if(lhs.rank() != rhs.rank()) return false;

std::string index(lhs.rank() ? "i0" : "");
Expand All @@ -30,16 +50,9 @@ bool approximately_equal(const Tensor& lhs, const Tensor& rhs, double tol) {
Tensor result;
result(index) = lhs(index) - rhs(index);

if(!allocator_type::can_rebind(result.buffer()))
throw std::runtime_error("Buffer is not filled with doubles");

auto& buffer_down = allocator_type::rebind(result.buffer());
using tensorwrapper::utilities::floating_point_dispatch;

for(std::size_t i = 0; i < buffer_down.size(); ++i) {
auto diff = *(buffer_down.data() + i);
if(std::fabs(diff) >= tol) return false;
}
return true;
return floating_point_dispatch(Kernel{}, result.buffer(), tol);
}

} // namespace tensorwrapper::operations
114 changes: 73 additions & 41 deletions tests/cxx/unit_tests/tensorwrapper/operations/approximately_equal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,64 +19,96 @@
using namespace tensorwrapper;
using namespace operations;

TEST_CASE("approximately_equal") {
Tensor scalar(42.0);
Tensor vector{1.23, 2.34};
/* Notes on testing.
*
* - Because of how floating point conversions work, a difference of the
* tolerance may be equal, slightly less than, or slightly more than the
* tolerance converted to a different floating point type. We do not test for
* exact equality to the tolerance.
* - We can test for positive and negative differences by flipping the order of
* arguments.
*/

TEMPLATE_LIST_TEST_CASE("approximately_equal", "",
types::floating_point_types) {
auto pscalar = testing::eigen_scalar<TestType>();
pscalar->at() = 42.0;
auto pvector = testing::eigen_vector<TestType>(2);
pvector->at(0) = 1.23;
pvector->at(1) = 2.34;

auto pscalar2 = testing::eigen_scalar<TestType>();
pscalar2->at() = 42.0;
auto pvector2 = testing::eigen_vector<TestType>(2);
pvector2->at(0) = 1.23;
pvector2->at(1) = 2.34;

shape::Smooth s0{};
shape::Smooth s1{2};

Tensor scalar(s0, std::move(pscalar));
Tensor vector(s1, std::move(pvector));

SECTION("different ranks") {
REQUIRE_FALSE(approximately_equal(scalar, vector));
REQUIRE_FALSE(approximately_equal(vector, scalar));
}

SECTION("Same values") {
REQUIRE(approximately_equal(scalar, Tensor(42.0)));
REQUIRE(approximately_equal(vector, Tensor{1.23, 2.34}));
}
Tensor scalar2(s0, std::move(pscalar2));
Tensor vector2(s1, std::move(pvector2));

SECTION("Differ by default tolerance") {
double value = 1e-16;
REQUIRE_FALSE(approximately_equal(Tensor(0.0), Tensor(value)));
REQUIRE_FALSE(
approximately_equal(Tensor{1.23, 0.0}, Tensor{1.23, value}));
REQUIRE_FALSE(
approximately_equal(Tensor{0.0, 2.34}, Tensor{value, 2.34}));
REQUIRE(approximately_equal(scalar, scalar2));
REQUIRE(approximately_equal(scalar2, scalar));
REQUIRE(approximately_equal(vector, vector2));
REQUIRE(approximately_equal(vector2, vector));
}

SECTION("Differ by more than default tolerance") {
double value = 1e-16;
REQUIRE_FALSE(approximately_equal(scalar, Tensor(value)));
REQUIRE_FALSE(approximately_equal(vector, Tensor{1.23, value}));
REQUIRE_FALSE(approximately_equal(vector, Tensor{value, 2.34}));
double value = 1e-1;
pscalar2->at() = 42.0 + value;
pvector2->at(0) = 1.23 + value;
Tensor scalar2(s0, std::move(pscalar2));
Tensor vector2(s1, std::move(pvector2));
REQUIRE_FALSE(approximately_equal(scalar, scalar2));
REQUIRE_FALSE(approximately_equal(scalar2, scalar));
REQUIRE_FALSE(approximately_equal(vector, vector2));
REQUIRE_FALSE(approximately_equal(vector2, vector));
}

SECTION("Differ by less than default tolerance") {
double value = 1e-17;
REQUIRE(approximately_equal(Tensor(0.0), Tensor(value)));
REQUIRE(approximately_equal(Tensor{1.23, 0.0}, Tensor{1.23, value}));
REQUIRE(approximately_equal(Tensor{0.0, 2.34}, Tensor{value, 2.34}));
}

SECTION("Differ by provided tolerance") {
double value = 1e-1;
REQUIRE_FALSE(approximately_equal(Tensor(0.0), Tensor(value), value));
REQUIRE_FALSE(
approximately_equal(Tensor{1.23, 0.0}, Tensor{1.23, value}, value));
REQUIRE_FALSE(
approximately_equal(Tensor{0.0, 2.34}, Tensor{value, 2.34}, value));
double value = 1e-17;
pscalar2->at() = 42.0 + value;
pvector2->at(0) = 1.23 + value;
Tensor scalar2(s0, std::move(pscalar2));
Tensor vector2(s1, std::move(pvector2));
REQUIRE(approximately_equal(scalar, scalar2));
REQUIRE(approximately_equal(scalar2, scalar));
REQUIRE(approximately_equal(vector, vector2));
REQUIRE(approximately_equal(vector2, vector));
}

SECTION("Differ by more than provided tolerance") {
double value = 1e-1;
REQUIRE_FALSE(approximately_equal(scalar, Tensor(value), value));
REQUIRE_FALSE(approximately_equal(vector, Tensor{1.23, value}, value));
REQUIRE_FALSE(approximately_equal(vector, Tensor{value, 2.34}, value));
float value = 1e-1;
pscalar2->at() = 43.0;
pvector2->at(0) = 2.23;
Tensor scalar2(s0, std::move(pscalar2));
Tensor vector2(s1, std::move(pvector2));
REQUIRE_FALSE(approximately_equal(scalar, scalar2, value));
REQUIRE_FALSE(approximately_equal(scalar2, scalar, value));
REQUIRE_FALSE(approximately_equal(vector, vector2, value));
REQUIRE_FALSE(approximately_equal(vector2, vector, value));
}

SECTION("Differ by less than provided tolerance") {
double value = 1e-2;
REQUIRE(approximately_equal(Tensor(0.0), Tensor(value), 1e-1));
REQUIRE(
approximately_equal(Tensor{1.23, 0.0}, Tensor{1.23, value}, 1e-1));
REQUIRE(
approximately_equal(Tensor{0.0, 2.34}, Tensor{value, 2.34}, 1e-1));
double value = 1e-10;
pscalar2->at() = 42.0 + value;
pvector2->at(0) = 1.23 + value;
Tensor scalar2(s0, std::move(pscalar2));
Tensor vector2(s1, std::move(pvector2));
REQUIRE(approximately_equal(scalar, scalar2, 1e-1));
REQUIRE(approximately_equal(scalar2, scalar, 1e-1));
REQUIRE(approximately_equal(vector, vector2, 1e-1));
REQUIRE(approximately_equal(vector2, vector, 1e-1));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2025 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../testing/testing.hpp"

using namespace tensorwrapper;
using namespace tensorwrapper::utilities;

/// Functor used to exercise both call patterns of floating_point_dispatch.
struct Kernel {
    /// One-argument overload with no return value: checks that the FloatType
    /// selected by dispatch matches the buffer's element type by comparing
    /// against a freshly created eigen_matrix<FloatType>.
    template<typename FloatType>
    void run(buffer::BufferBase& buffer) {
        auto corr = testing::eigen_matrix<FloatType>();
        REQUIRE(corr->are_equal(buffer));
    }

    /// Two-argument overload with a bool return: verifies that extra
    /// arguments are forwarded to the kernel and that its return value
    /// propagates back through floating_point_dispatch.
    template<typename FloatType>
    bool run(buffer::BufferBase& buffer, buffer::BufferBase& corr) {
        return corr.are_equal(buffer);
    }
};

// Runs once per supported floating-point type; TestType is the element type
// of the buffer that floating_point_dispatch must detect.
TEMPLATE_LIST_TEST_CASE("floating_point_dispatch", "",
                        types::floating_point_types) {
    Kernel kernel;
    auto tensor = testing::eigen_matrix<TestType>();

    SECTION("Single input, no return") {
        // The kernel's internal REQUIRE verifies the dispatched type.
        floating_point_dispatch(kernel, *tensor);
    }

    SECTION("Two inputs and a return") {
        // Checks that the extra argument is forwarded and the kernel's
        // return value is passed back to the caller.
        REQUIRE(floating_point_dispatch(kernel, *tensor, *tensor));
    }
}