Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/gems/expression/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import List, cast

from .expression import (
AdditionNode,
AllTimeSumNode,
CeilNode,
ComparisonNode,
Expand Down Expand Up @@ -44,6 +45,9 @@ class CopyVisitor(ExpressionVisitorOperations[ExpressionNode]):
Simply copies the whole AST.
"""

def addition(self, node: AdditionNode) -> ExpressionNode:
return AdditionNode([visit(o, self) for o in node.operands])

def literal(self, node: LiteralNode) -> ExpressionNode:
return LiteralNode(node.value)

Expand Down
40 changes: 40 additions & 0 deletions tests/unittests/expressions/visitor/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# This file is part of the Antares project.


import time

from gems.expression import (
AdditionNode,
DivisionNode,
Expand All @@ -29,6 +31,44 @@
)


def test_copy_large_addition_is_linear() -> None:
"""
Copying an AdditionNode with T operands must be O(T), not O(T²).

Before the fix, CopyVisitor inherited the default addition() from
ExpressionVisitorOperations which accumulated results with `res = res + o`.
Each call to __add__ flattens the AdditionNode by copying the accumulated
operand list, giving 1+2+...+(T-1) = O(T²) list copies in total.

The fix overrides addition() in CopyVisitor to build AdditionNode directly
from a list comprehension, reducing the cost to O(T).

We verify linearity by checking that the time ratio between T=10_000 and
T=1_000 stays below 20 (linear ≈ 10, quadratic ≈ 100).
"""
small_n = 1_000
large_n = 10_000

small_node = AdditionNode([VariableNode(f"x{i}") for i in range(small_n)])
large_node = AdditionNode([VariableNode(f"x{i}") for i in range(large_n)])

t0 = time.perf_counter()
copy_expression(small_node)
small_time = time.perf_counter() - t0

t0 = time.perf_counter()
copy_expression(large_node)
large_time = time.perf_counter() - t0

ratio = large_time / small_time
assert ratio < 20, (
f"copy_expression scaling looks super-linear: "
f"T={small_n} took {small_time:.4f}s, "
f"T={large_n} took {large_time:.4f}s, "
f"ratio={ratio:.1f} (expected <20 for O(T))"
)


def test_copy_ast() -> None:
ast = AllTimeSumNode(
DivisionNode(
Expand Down
Loading