diff --git a/src/gems/expression/parsing/parse_expression.py b/src/gems/expression/parsing/parse_expression.py index f001f9c2..47316bf8 100644 --- a/src/gems/expression/parsing/parse_expression.py +++ b/src/gems/expression/parsing/parse_expression.py @@ -149,6 +149,22 @@ def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> ExpressionNode: return shifted_expr return shifted_expr.shift(time_shift) + def visitTimeShiftExpr( + self, ctx: ExprParser.TimeShiftExprContext + ) -> ExpressionNode: + shifted_expr = ctx.expr().accept(self) # type: ignore + time_shift = ctx.shift().accept(self) # type: ignore + if expressions_equal(time_shift, literal(0)): + return shifted_expr + return shifted_expr.shift(time_shift) + + def visitTimeIndexExpr( + self, ctx: ExprParser.TimeIndexExprContext + ) -> ExpressionNode: + expr = ctx.expr(0).accept(self) # type: ignore + eval_time = ctx.expr(1).accept(self) # type: ignore + return expr.eval(eval_time) + def visitTimeSum(self, ctx: ExprParser.TimeSumContext) -> ExpressionNode: shifted_expr = ctx.expr().accept(self) # type: ignore from_shift = ctx.from_.accept(self) # type: ignore diff --git a/src/gems/simulation/linearize.py b/src/gems/simulation/linearize.py index 0d3c9e71..a360effb 100644 --- a/src/gems/simulation/linearize.py +++ b/src/gems/simulation/linearize.py @@ -10,6 +10,7 @@ # # This file is part of the Antares project. +import math from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union @@ -205,16 +206,36 @@ def literal(self, node: LiteralNode) -> LinearExpressionData: return LinearExpressionData([], node.value) def floor(self, node: FloorNode) -> LinearExpressionData: - raise ValueError("Linear expression cannot contain a floor operator.") + operand = visit(node.operand, self) + if operand.terms: + raise ValueError( + "Linear expression cannot contain a floor operator on a non-constant expression." + ) + return LinearExpressionData([], math.floor(operand.constant)) def ceil(self, node: CeilNode) -> LinearExpressionData: - raise ValueError("Linear expression cannot contain a ceil operator.") + operand = visit(node.operand, self) + if operand.terms: + raise ValueError( + "Linear expression cannot contain a ceil operator on a non-constant expression." + ) + return LinearExpressionData([], math.ceil(operand.constant)) def maximum(self, node: MaxNode) -> LinearExpressionData: - raise ValueError("Linear expression cannot contain a max operator.") + operands = [visit(op, self) for op in node.operands] + if any(op.terms for op in operands): + raise ValueError( + "Linear expression cannot contain a max operator on a non-constant expression." + ) + return LinearExpressionData([], max(op.constant for op in operands)) def minimum(self, node: MinNode) -> LinearExpressionData: - raise ValueError("Linear expression cannot contain a min operator.") + operands = [visit(op, self) for op in node.operands] + if any(op.terms for op in operands): + raise ValueError( + "Linear expression cannot contain a min operator on a non-constant expression." + ) + return LinearExpressionData([], min(op.constant for op in operands)) def comparison(self, node: ComparisonNode) -> LinearExpressionData: raise ValueError("Linear expression cannot contain a comparison operator.") diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index 7287bed3..ff052562 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -167,6 +167,34 @@ "min(a, b, c)", minimum(param("a"), param("b"), param("c")), ), + ( + {"x", "y"}, + {}, + "(x + y)[t-1]", + (var("x") + var("y")).shift(-literal(1)), + ), + ( + {}, + {"p", "q"}, + "(ceil(p/q))[t]", + (param("p") / param("q")).ceil(), + ), + ( + {}, + {"p_max_cluster", "p_max_unit"}, + "max(0, (ceil(p_max_cluster/p_max_unit))[t-1] - (ceil(p_max_cluster/p_max_unit)))", + maximum( + literal(0), + (param("p_max_cluster") / param("p_max_unit")).ceil().shift(-literal(1)) + - (param("p_max_cluster") / param("p_max_unit")).ceil(), + ), + ), + ( + {"x", "y"}, + {}, + "(x + y)[1]", + (var("x") + var("y")).eval(literal(1)), + ), ], ) def test_parsing_visitor( diff --git a/tests/unittests/expressions/visitor/test_linearization.py b/tests/unittests/expressions/visitor/test_linearization.py index 64a033a4..d64741e1 100644 --- a/tests/unittests/expressions/visitor/test_linearization.py +++ b/tests/unittests/expressions/visitor/test_linearization.py @@ -21,6 +21,8 @@ TimeShift, comp_param, comp_var, + maximum, + minimum, problem_var, ) from gems.expression.indexing import IndexingStructureProvider @@ -160,3 +162,59 @@ def test_invalid_division() -> None: expression = literal(1) / x with pytest.raises(ValueError, match="constant"): linearize_expression(expression, 0, 0, params) + + +def test_floor_of_constant() -> None: + assert linearize_expression( + literal(2.7).floor(), timestep=0, scenario=0 + ) == constant(2) + + +def test_ceil_of_constant() -> None: + assert linearize_expression( + literal(2.3).ceil(), timestep=0, scenario=0 + ) == constant(3) + + +def test_max_of_constants() -> None: + assert linearize_expression( + maximum(literal(3), literal(5)), timestep=0, scenario=0 + ) == constant(5) + + +def test_min_of_constants() -> None: + assert linearize_expression( + minimum(literal(3), literal(5)), timestep=0, scenario=0 + ) == constant(3) + + +def test_floor_of_variable_raises() -> None: + x = problem_var( + "c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex() + ) + with pytest.raises(ValueError, match="non-constant"): + linearize_expression(x.floor(), timestep=0, scenario=0) + + +def test_ceil_of_variable_raises() -> None: + x = problem_var( + "c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex() + ) + with pytest.raises(ValueError, match="non-constant"): + linearize_expression(x.ceil(), timestep=0, scenario=0) + + +def test_max_with_variable_raises() -> None: + x = problem_var( + "c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex() + ) + with pytest.raises(ValueError, match="non-constant"): + linearize_expression(maximum(x, literal(5)), timestep=0, scenario=0) + + +def test_min_with_variable_raises() -> None: + x = problem_var( + "c", "x", time_index=TimeShift(0), scenario_index=CurrentScenarioIndex() + ) + with pytest.raises(ValueError, match="non-constant"): + linearize_expression(minimum(x, literal(5)), timestep=0, scenario=0)