Skip to content

Commit bd319d8

Browse files
aoustryclaude
andauthored
Complementary devs for floor, ceil, min, and max operations (#186)
* fix: implement floor/ceil/max/min visitors in LinearExpressionBuilder These operators are valid in a linear expression when all operands are constant. Previously they raised unconditionally; now they evaluate to a constant result, and only raise ValueError when a non-constant (variable-containing) operand is present. https://claude.ai/code/session_01DL3tB7dnWvHsB8dNYBBNxZ * style: apply black formatting to test_linearization.py https://claude.ai/code/session_01DL3tB7dnWvHsB8dNYBBNxZ * feat: add visitTimeShiftExpr and visitTimeIndexExpr to expression visitor Adds support for parenthesized-expression time-shift/index syntax such as `(ceil(p/q))[t-1]`, which previously failed to parse because the ANTLR visitor lacked handlers for the timeShiftExpr and timeIndexExpr grammar rules. The new methods mirror the existing visitTimeShift and visitTimeIndex logic, using ctx.expr().accept(self) instead of _convert_identifier() to obtain the base expression. https://claude.ai/code/session_01So7yDQ3KiPj8oDa4xYnhUK * style: fix line length in visitTimeShiftExpr and visitTimeIndexExpr signatures https://claude.ai/code/session_01So7yDQ3KiPj8oDa4xYnhUK --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent 158da98 commit bd319d8

File tree

4 files changed

+127
-4
lines changed

4 files changed

+127
-4
lines changed

src/gems/expression/parsing/parse_expression.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,22 @@ def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> ExpressionNode:
149149
return shifted_expr
150150
return shifted_expr.shift(time_shift)
151151

152+
def visitTimeShiftExpr(
153+
self, ctx: ExprParser.TimeShiftExprContext
154+
) -> ExpressionNode:
155+
shifted_expr = ctx.expr().accept(self) # type: ignore
156+
time_shift = ctx.shift().accept(self) # type: ignore
157+
if expressions_equal(time_shift, literal(0)):
158+
return shifted_expr
159+
return shifted_expr.shift(time_shift)
160+
161+
def visitTimeIndexExpr(
162+
self, ctx: ExprParser.TimeIndexExprContext
163+
) -> ExpressionNode:
164+
expr = ctx.expr(0).accept(self) # type: ignore
165+
eval_time = ctx.expr(1).accept(self) # type: ignore
166+
return expr.eval(eval_time)
167+
152168
def visitTimeSum(self, ctx: ExprParser.TimeSumContext) -> ExpressionNode:
153169
shifted_expr = ctx.expr().accept(self) # type: ignore
154170
from_shift = ctx.from_.accept(self) # type: ignore

src/gems/simulation/linearize.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#
1111
# This file is part of the Antares project.
1212

13+
import math
1314
from abc import ABC, abstractmethod
1415
from dataclasses import dataclass
1516
from typing import Any, Dict, List, Optional, Union
@@ -205,16 +206,36 @@ def literal(self, node: LiteralNode) -> LinearExpressionData:
205206
return LinearExpressionData([], node.value)
206207

207208
def floor(self, node: FloorNode) -> LinearExpressionData:
208-
raise ValueError("Linear expression cannot contain a floor operator.")
209+
operand = visit(node.operand, self)
210+
if operand.terms:
211+
raise ValueError(
212+
"Linear expression cannot contain a floor operator on a non-constant expression."
213+
)
214+
return LinearExpressionData([], math.floor(operand.constant))
209215

210216
def ceil(self, node: CeilNode) -> LinearExpressionData:
211-
raise ValueError("Linear expression cannot contain a ceil operator.")
217+
operand = visit(node.operand, self)
218+
if operand.terms:
219+
raise ValueError(
220+
"Linear expression cannot contain a ceil operator on a non-constant expression."
221+
)
222+
return LinearExpressionData([], math.ceil(operand.constant))
212223

213224
def maximum(self, node: MaxNode) -> LinearExpressionData:
214-
raise ValueError("Linear expression cannot contain a max operator.")
225+
operands = [visit(op, self) for op in node.operands]
226+
if any(op.terms for op in operands):
227+
raise ValueError(
228+
"Linear expression cannot contain a max operator on a non-constant expression."
229+
)
230+
return LinearExpressionData([], max(op.constant for op in operands))
215231

216232
def minimum(self, node: MinNode) -> LinearExpressionData:
217-
raise ValueError("Linear expression cannot contain a min operator.")
233+
operands = [visit(op, self) for op in node.operands]
234+
if any(op.terms for op in operands):
235+
raise ValueError(
236+
"Linear expression cannot contain a min operator on a non-constant expression."
237+
)
238+
return LinearExpressionData([], min(op.constant for op in operands))
218239

219240
def comparison(self, node: ComparisonNode) -> LinearExpressionData:
220241
raise ValueError("Linear expression cannot contain a comparison operator.")

tests/unittests/expressions/parsing/test_expression_parsing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,34 @@
167167
"min(a, b, c)",
168168
minimum(param("a"), param("b"), param("c")),
169169
),
170+
(
171+
{"x", "y"},
172+
{},
173+
"(x + y)[t-1]",
174+
(var("x") + var("y")).shift(-literal(1)),
175+
),
176+
(
177+
{},
178+
{"p", "q"},
179+
"(ceil(p/q))[t]",
180+
(param("p") / param("q")).ceil(),
181+
),
182+
(
183+
{},
184+
{"p_max_cluster", "p_max_unit"},
185+
"max(0, (ceil(p_max_cluster/p_max_unit))[t-1] - (ceil(p_max_cluster/p_max_unit)))",
186+
maximum(
187+
literal(0),
188+
(param("p_max_cluster") / param("p_max_unit")).ceil().shift(-literal(1))
189+
- (param("p_max_cluster") / param("p_max_unit")).ceil(),
190+
),
191+
),
192+
(
193+
{"x", "y"},
194+
{},
195+
"(x + y)[1]",
196+
(var("x") + var("y")).eval(literal(1)),
197+
),
170198
],
171199
)
172200
def test_parsing_visitor(

tests/unittests/expressions/visitor/test_linearization.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
TimeShift,
2222
comp_param,
2323
comp_var,
24+
maximum,
25+
minimum,
2426
problem_var,
2527
)
2628
from gems.expression.indexing import IndexingStructureProvider
@@ -160,3 +162,59 @@ def test_invalid_division() -> None:
160162
expression = literal(1) / x
161163
with pytest.raises(ValueError, match="constant"):
162164
linearize_expression(expression, 0, 0, params)
165+
166+
167+
def test_floor_of_constant() -> None:
168+
assert linearize_expression(
169+
literal(2.7).floor(), timestep=0, scenario=0
170+
) == constant(2)
171+
172+
173+
def test_ceil_of_constant() -> None:
174+
assert linearize_expression(
175+
literal(2.3).ceil(), timestep=0, scenario=0
176+
) == constant(3)
177+
178+
179+
def test_max_of_constants() -> None:
180+
assert linearize_expression(
181+
maximum(literal(3), literal(5)), timestep=0, scenario=0
182+
) == constant(5)
183+
184+
185+
def test_min_of_constants() -> None:
186+
assert linearize_expression(
187+
minimum(literal(3), literal(5)), timestep=0, scenario=0
188+
) == constant(3)
189+
190+
191+
def test_floor_of_variable_raises() -> None:
192+
x = problem_var(
193+
"c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex()
194+
)
195+
with pytest.raises(ValueError, match="non-constant"):
196+
linearize_expression(x.floor(), timestep=0, scenario=0)
197+
198+
199+
def test_ceil_of_variable_raises() -> None:
200+
x = problem_var(
201+
"c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex()
202+
)
203+
with pytest.raises(ValueError, match="non-constant"):
204+
linearize_expression(x.ceil(), timestep=0, scenario=0)
205+
206+
207+
def test_max_with_variable_raises() -> None:
208+
x = problem_var(
209+
"c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex()
210+
)
211+
with pytest.raises(ValueError, match="non-constant"):
212+
linearize_expression(maximum(x, literal(5)), timestep=0, scenario=0)
213+
214+
215+
def test_min_with_variable_raises() -> None:
216+
x = problem_var(
217+
"c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex()
218+
)
219+
with pytest.raises(ValueError, match="non-constant"):
220+
linearize_expression(minimum(x, literal(5)), timestep=0, scenario=0)

0 commit comments

Comments
 (0)