Skip to content

Commit

Permalink
Going F A S T
Browse files Browse the repository at this point in the history
  • Loading branch information
Iluvmagick committed Feb 27, 2025
1 parent c4ed86b commit 51bbb3e
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ namespace nil {
return !(rhs == *this && _d == rhs._d);
}
bool operator<(const polynomial_dfs& rhs) const {
return std::tie(val, _d) < std::tie(rhs.val, rhs._d);
return std::tie(_d, val) < std::tie(rhs._d, rhs.val);
}

allocator_type get_allocator() const BOOST_NOEXCEPT {
Expand Down Expand Up @@ -583,7 +583,6 @@ namespace nil {
this->resize(polynomial_s, domain, new_domain);
}


// Change the degree only here, after a possible resize, otherwise we have a polynomial
// with a high degree but small size, which sometimes segfaults.
this->_d += other._d;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ namespace nil {
using map_type = std::map<
std::shared_ptr<arena_node<VariableType>>,
size_t,
cmp_type
arena_node_ptr_comparator<arena_node<VariableType>>
>;
virtual assignment_type evaluate(
std::function<assignment_type(const VariableType&)> &evaluation_map
Expand Down Expand Up @@ -176,17 +176,14 @@ namespace nil {
using assignment_type = typename VariableType::assignment_type;
using base_type = arena_node<VariableType>;
using map_type = typename base_type::map_type;
std::optional<assignment_type> value;
assignment_type value;

arena_constant(const std::optional<assignment_type> &_value) :
arena_constant(const assignment_type &_value) :
value(_value) {}

arena_constant(const arena_constant &other) :
value(other.value) {}

arena_constant(const assignment_type &_value) :
value(_value) {}

virtual base_type::NodeType get_type() const override {
return base_type::NodeType::Constant;
}
Expand All @@ -196,7 +193,7 @@ namespace nil {
virtual assignment_type evaluate(
std::function<assignment_type(const VariableType&)> &evaluation_map
) override {
return value.value();
return value;
}

map_type children() const override {
Expand All @@ -207,24 +204,22 @@ namespace nil {
if (base_type::type_mismatch(other)) {
return false;
}
return value.value() == static_cast<const arena_constant<VariableType>&>(other).value.value();
return value == static_cast<const arena_constant<VariableType>&>(other).value;
}

bool operator<(const arena_node<VariableType> &other) const override {
if (base_type::type_mismatch(other)) {
return base_type::type_less(other);
}
return value.value() < static_cast<const arena_constant<VariableType>&>(other).value.value();
return value < static_cast<const arena_constant<VariableType>&>(other).value;
}

~arena_constant() = default;
};

template<typename VariableType>
std::shared_ptr<arena_node<VariableType>> count_map_insert(
std::map<std::shared_ptr<arena_node<VariableType>>,
size_t, arena_node_ptr_comparator<arena_node<VariableType>>
> &count_map,
typename arena_node<VariableType>::map_type &count_map,
std::shared_ptr<arena_node<VariableType>> node,
std::size_t count = 1
) {
Expand Down Expand Up @@ -371,7 +366,7 @@ namespace nil {
if (base_type::type_mismatch(other)) {
return base_type::type_less(other);
}
return base_type::children_map < static_cast<const arena_add<VariableType>&>(other).children_map;
return this->children_map < static_cast<const arena_add<VariableType>&>(other).children_map;
}

~arena_add() = default;
Expand Down Expand Up @@ -427,14 +422,14 @@ namespace nil {
if (base_type::type_mismatch(other)) {
return false;
}
return base_type::children_map == static_cast<const arena_mul<VariableType>&>(other).children_map;
return this->children_map == static_cast<const arena_mul<VariableType>&>(other).children_map;
}

bool operator<(const arena_node<VariableType> &other) const override {
if (base_type::type_mismatch(other)) {
return base_type::type_less(other);
}
return base_type::children_map < static_cast<const arena_mul<VariableType>&>(other).children_map;
return this->children_map < static_cast<const arena_mul<VariableType>&>(other).children_map;
}

~arena_mul() = default;
Expand Down Expand Up @@ -553,11 +548,7 @@ namespace nil {
struct arena_expression {
using assignment_type = typename VariableType::assignment_type;
using node_type = arena_node<VariableType>;
using map_type = std::map<
std::shared_ptr<node_type>,
size_t,
arena_node_ptr_comparator<node_type>
>;
using map_type = typename node_type::map_type;
// for translating math::expression to arena_expression
using math_expression_type = math::expression<VariableType>;
using term_type = math::term<VariableType>;
Expand Down Expand Up @@ -649,10 +640,8 @@ namespace nil {
}

arena_expression(const arena_expression& other) {
// Map to track which nodes have already been copied to avoid duplicating shared nodes
std::unordered_map<std::shared_ptr<node_type>, std::shared_ptr<node_type>> node_map;

// Copy all root nodes
for (const auto& root_node : other.root_nodes) {
auto copied_child = deep_copy_node(root_node->child, node_map);
root_nodes.push_back(make_arena_root<VariableType>(copied_child));
Expand All @@ -663,15 +652,15 @@ namespace nil {
if (this != &other) {
// Create a temporary copy and swap
arena_expression temp(other);
std::swap(ops, temp.ops);
std::swap(root_nodes, temp.root_nodes);
ops = std::move(temp.ops);
root_nodes = std::move(temp.root_nodes);
}
return *this;
}

void clear_cache() {
for (const auto& root_node : root_nodes) {
root_node->result = std::nullopt;
root_node->clear_cache();
}
for (const auto& [key, value] : ops) {
key->clear_cache();
Expand Down Expand Up @@ -704,14 +693,10 @@ namespace nil {
auto const_try = std::dynamic_pointer_cast<arena_constant<VariableType>>(node);
if (const_try != nullptr) {
return make_arena_node<NewVariableType>(
arena_constant<NewVariableType>(convert_const(*(const_try->value)))
arena_constant<NewVariableType>(convert_const(const_try->value))
);
}
using new_map_type = std::map<
std::shared_ptr<arena_node<NewVariableType>>,
size_t,
arena_node_ptr_comparator<arena_node<NewVariableType>>
>;
using new_map_type = arena_node<NewVariableType>::map_type;
auto add_try = std::dynamic_pointer_cast<arena_add<VariableType>>(node);
if (add_try != nullptr) {
new_map_type children;
Expand Down Expand Up @@ -753,7 +738,7 @@ namespace nil {
// Check if we've already copied this node
auto it = node_map.find(node);
if (it != node_map.end()) {
return it->second;
return insert_op(it->second);
}

// Copy the node based on its concrete type
Expand All @@ -762,50 +747,43 @@ namespace nil {
// Handle constants
auto const_try = std::dynamic_pointer_cast<arena_constant<VariableType>>(node);
if (const_try != nullptr) {
// Make a deep copy of the value
new_node = std::make_shared<arena_constant<VariableType>>(const_try->value);
}
// Handle variables
else if (auto var_try = std::dynamic_pointer_cast<arena_variable<VariableType>>(node)) {
auto var_try = std::dynamic_pointer_cast<arena_variable<VariableType>>(node);
if (var_try != nullptr) {
new_node = std::make_shared<arena_variable<VariableType>>(var_try->var);
}
// Handle addition operations
else if (auto add_try = std::dynamic_pointer_cast<arena_add<VariableType>>(node)) {
// Recursively copy all children
auto add_try = std::dynamic_pointer_cast<arena_add<VariableType>>(node);
if (add_try != nullptr) {
typename node_type::map_type new_children;
for (const auto& [child, count] : add_try->children()) {
auto copied_child = deep_copy_node(child, node_map);
count_map_insert(new_children, copied_child, count);
}
new_node = std::make_shared<arena_add<VariableType>>(new_children);
}
// Handle multiplication operations
else if (auto mul_try = std::dynamic_pointer_cast<arena_mul<VariableType>>(node)) {
// Recursively copy all children
auto mul_try = std::dynamic_pointer_cast<arena_mul<VariableType>>(node);
if (mul_try != nullptr) {
typename node_type::map_type new_children;
for (const auto& [child, count] : mul_try->children()) {
auto copied_child = deep_copy_node(child, node_map);
count_map_insert(new_children, copied_child, count);
}
new_node = std::make_shared<arena_mul<VariableType>>(new_children);
}
// Handle negation operations
else if (auto neg_try = std::dynamic_pointer_cast<arena_negate<VariableType>>(node)) {
auto neg_try = std::dynamic_pointer_cast<arena_negate<VariableType>>(node);
if (neg_try != nullptr) {
new_node = std::make_shared<arena_negate<VariableType>>(
deep_copy_node(neg_try->child, node_map)
);
}
// Handle root nodes
else if (auto root_try = std::dynamic_pointer_cast<arena_root<VariableType>>(node)) {
auto root_try = std::dynamic_pointer_cast<arena_root<VariableType>>(node);
if (root_try != nullptr) {
new_node = std::make_shared<arena_root<VariableType>>(
deep_copy_node(root_try->child, node_map)
);
}
else {
throw std::runtime_error("Unknown node type in deep_copy_node");
}

// Store the mapping and add the node to our ops map
node_map[node] = new_node;
insert_op(new_node);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ namespace nil {
[&variable_values, &extended_domain_sizes, &result, &expressions, i, arena_expr]
(std::size_t begin, std::size_t end) {
auto arena_expr_copy = arena_expr;
arena_expr_copy.clear_cache();
for (std::size_t j = begin; j < end; ++j) {
std::function<value_type(const variable_type &)> eval_map =
[&variable_values, j](const variable_type &var) -> value_type {
Expand All @@ -226,7 +227,7 @@ namespace nil {
arena_expr_copy.clear_cache();
}
}, ThreadPool::PoolLevel::HIGH));

F[0] += result;
};
return F;
}
Expand Down

0 comments on commit 51bbb3e

Please sign in to comment.