diff --git a/.gitignore b/.gitignore index 1edab6c2..59318f35 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ bin outputs build src/gems.egg-info/ +src/gemspy.egg-info/ diff --git a/grammar/Expr.g4 b/grammar/Expr.g4 index 73e4d0ba..92928f38 100644 --- a/grammar/Expr.g4 +++ b/grammar/Expr.g4 @@ -29,6 +29,7 @@ expr | 'sum_connections' '(' portFieldExpr ')' # portFieldSum | 'sum' '(' from=shift '..' to=shift ',' expr ')' # timeSum | IDENTIFIER '(' expr ')' # function + | IDENTIFIER '(' expr ',' expr ')' # binaryFunction | IDENTIFIER '[' shift ']' # timeShift | IDENTIFIER '[' expr ']' # timeIndex | '(' expr ')' '[' shift ']' # timeShiftExpr diff --git a/src/gems/expression/__init__.py b/src/gems/expression/__init__.py index 3d949b6e..1eafc790 100644 --- a/src/gems/expression/__init__.py +++ b/src/gems/expression/__init__.py @@ -20,16 +20,22 @@ ) from .expression import ( AdditionNode, + CeilNode, Comparator, ComparisonNode, DivisionNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, MultiplicationNode, NegationNode, ParameterNode, VariableNode, literal, + maximum, + minimum, param, sum_expressions, var, diff --git a/src/gems/expression/copy.py b/src/gems/expression/copy.py index eea34d74..108b98ac 100644 --- a/src/gems/expression/copy.py +++ b/src/gems/expression/copy.py @@ -15,11 +15,15 @@ from .expression import ( AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, ParameterNode, PortFieldAggregatorNode, PortFieldNode, @@ -95,6 +99,18 @@ def port_field(self, node: PortFieldNode) -> ExpressionNode: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode: return PortFieldAggregatorNode(visit(node.operand, self), node.aggregator) + def floor(self, node: FloorNode) -> ExpressionNode: + return FloorNode(visit(node.operand, self)) + + def ceil(self, node: CeilNode) -> ExpressionNode: + return CeilNode(visit(node.operand, self)) + + def maximum(self, node: MaxNode) -> ExpressionNode: + return MaxNode(visit(node.left, self), visit(node.right, self)) + + def minimum(self, node: MinNode) -> ExpressionNode: + return MinNode(visit(node.left, self), visit(node.right, self)) + def copy_expression(expression: ExpressionNode) -> ExpressionNode: return visit(expression, CopyVisitor()) diff --git a/src/gems/expression/degree.py b/src/gems/expression/degree.py index d9051818..0be2c5fe 100644 --- a/src/gems/expression/degree.py +++ b/src/gems/expression/degree.py @@ -10,11 +10,17 @@ # # This file is part of the Antares project. +import math + import gems.expression.scenario_operator from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -39,77 +45,95 @@ from .visitor import ExpressionVisitor, T, visit -class ExpressionDegreeVisitor(ExpressionVisitor[int]): +class ExpressionDegreeVisitor(ExpressionVisitor[int | float]): """ Computes degree of expression with respect to variables. """ - def literal(self, node: LiteralNode) -> int: + def literal(self, node: LiteralNode) -> int | float: return 0 - def negation(self, node: NegationNode) -> int: + def negation(self, node: NegationNode) -> int | float: return visit(node.operand, self) # TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div - def addition(self, node: AdditionNode) -> int: + def addition(self, node: AdditionNode) -> int | float: degrees = [visit(o, self) for o in node.operands] return max(degrees) - def multiplication(self, node: MultiplicationNode) -> int: + def multiplication(self, node: MultiplicationNode) -> int | float: return visit(node.left, self) + visit(node.right, self) - def division(self, node: DivisionNode) -> int: + def division(self, node: DivisionNode) -> int | float: right_degree = visit(node.right, self) if right_degree != 0: raise ValueError("Degree computation not implemented for divisions.") return visit(node.left, self) - def comparison(self, node: ComparisonNode) -> int: + def comparison(self, node: ComparisonNode) -> int | float: return max(visit(node.left, self), visit(node.right, self)) - def variable(self, node: VariableNode) -> int: + def variable(self, node: VariableNode) -> int | float: return 1 - def parameter(self, node: ParameterNode) -> int: + def parameter(self, node: ParameterNode) -> int | float: return 0 - def comp_variable(self, node: ComponentVariableNode) -> int: + def comp_variable(self, node: ComponentVariableNode) -> int | float: return 1 - def comp_parameter(self, node: ComponentParameterNode) -> int: + def comp_parameter(self, node: ComponentParameterNode) -> int | float: return 0 - def pb_variable(self, node: ProblemVariableNode) -> int: + def pb_variable(self, node: ProblemVariableNode) -> int | float: return 1 - def pb_parameter(self, node: ProblemParameterNode) -> int: + def pb_parameter(self, node: ProblemParameterNode) -> int | float: return 0 - def time_shift(self, node: TimeShiftNode) -> int: + def time_shift(self, node: TimeShiftNode) -> int | float: return visit(node.operand, self) - def time_eval(self, node: TimeEvalNode) -> int: + def time_eval(self, node: TimeEvalNode) -> int | float: return visit(node.operand, self) - def time_sum(self, node: TimeSumNode) -> int: + def time_sum(self, node: TimeSumNode) -> int | float: return visit(node.operand, self) - def all_time_sum(self, node: AllTimeSumNode) -> int: + def all_time_sum(self, node: AllTimeSumNode) -> int | float: return visit(node.operand, self) - def scenario_operator(self, node: ScenarioOperatorNode) -> int: + def scenario_operator(self, node: ScenarioOperatorNode) -> int | float: scenario_operator_cls = getattr(gems.expression.scenario_operator, node.name) # TODO: Carefully check if this formula is correct return scenario_operator_cls.degree() * visit(node.operand, self) - def port_field(self, node: PortFieldNode) -> int: + def port_field(self, node: PortFieldNode) -> int | float: return 1 - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int: + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int | float: return visit(node.operand, self) + def floor(self, node: FloorNode) -> int | float: + d = visit(node.operand, self) + return 0 if d == 0 else math.inf + + def ceil(self, node: CeilNode) -> int | float: + d = visit(node.operand, self) + return 0 if d == 0 else math.inf + + def maximum(self, node: MaxNode) -> int | float: + d_l = visit(node.left, self) + d_r = visit(node.right, self) + return 0 if d_l == 0 and d_r == 0 else math.inf + + def minimum(self, node: MinNode) -> int | float: + d_l = visit(node.left, self) + d_r = visit(node.right, self) + return 0 if d_l == 0 and d_r == 0 else math.inf + -def compute_degree(expression: ExpressionNode) -> int: +def compute_degree(expression: ExpressionNode) -> int | float: return visit(expression, ExpressionDegreeVisitor()) diff --git a/src/gems/expression/equality.py b/src/gems/expression/equality.py index d5a543ff..0517269e 100644 --- a/src/gems/expression/equality.py +++ b/src/gems/expression/equality.py @@ -28,8 +28,12 @@ from gems.expression.expression import ( AllTimeSumNode, BinaryOperatorNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -111,6 +115,14 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: right, PortFieldAggregatorNode ): return self.port_field_aggregator(left, right) + if isinstance(left, FloorNode) and isinstance(right, FloorNode): + return self.floor(left, right) + if isinstance(left, CeilNode) and isinstance(right, CeilNode): + return self.ceil(left, right) + if isinstance(left, MaxNode) and isinstance(right, MaxNode): + return self.maximum(left, right) + if isinstance(left, MinNode) and isinstance(right, MinNode): + return self.minimum(left, right) raise NotImplementedError(f"Equality not implemented for {left.__class__}") def literal(self, left: LiteralNode, right: LiteralNode) -> bool: @@ -215,6 +227,18 @@ def port_field_aggregator( left.operand, right.operand ) + def floor(self, left: FloorNode, right: FloorNode) -> bool: + return self.visit(left.operand, right.operand) + + def ceil(self, left: CeilNode, right: CeilNode) -> bool: + return self.visit(left.operand, right.operand) + + def maximum(self, left: MaxNode, right: MaxNode) -> bool: + return self._visit_operands(left, right) + + def minimum(self, left: MinNode, right: MinNode) -> bool: + return self._visit_operands(left, right) + def expressions_equal( left: ExpressionNode, right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0 diff --git a/src/gems/expression/evaluate.py b/src/gems/expression/evaluate.py index 398db7da..4d3c3770 100644 --- a/src/gems/expression/evaluate.py +++ b/src/gems/expression/evaluate.py @@ -10,14 +10,19 @@ # # This file is part of the Antares project. +import math from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Dict from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -139,6 +144,18 @@ def port_field(self, node: PortFieldNode) -> float: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float: raise NotImplementedError() + def floor(self, node: FloorNode) -> float: + return float(math.floor(visit(node.operand, self))) + + def ceil(self, node: CeilNode) -> float: + return float(math.ceil(visit(node.operand, self))) + + def maximum(self, node: MaxNode) -> float: + return max(visit(node.left, self), visit(node.right, self)) + + def minimum(self, node: MinNode) -> float: + return min(visit(node.left, self), visit(node.right, self)) + def evaluate(expression: ExpressionNode, value_provider: ValueProvider) -> float: return visit(expression, EvaluationVisitor(value_provider)) diff --git a/src/gems/expression/expression.py b/src/gems/expression/expression.py index 44bfcfc9..94405f0c 100644 --- a/src/gems/expression/expression.py +++ b/src/gems/expression/expression.py @@ -13,6 +13,7 @@ """ Defines the model for generic expressions. """ + import enum import inspect from dataclasses import dataclass @@ -117,6 +118,12 @@ def shift(self, shift: AnyExpression) -> "ExpressionNode": def eval(self, time: AnyExpression) -> "ExpressionNode": return TimeEvalNode(self, _wrap_in_node(time)) + def floor(self) -> "ExpressionNode": + return FloorNode(self) + + def ceil(self) -> "ExpressionNode": + return CeilNode(self) + def expec(self) -> "ExpressionNode": return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) @@ -368,6 +375,16 @@ class NegationNode(UnaryOperatorNode): pass +@dataclass(frozen=True, eq=False) +class FloorNode(UnaryOperatorNode): + pass + + +@dataclass(frozen=True, eq=False) +class CeilNode(UnaryOperatorNode): + pass + + @dataclass(frozen=True, eq=False) class BinaryOperatorNode(ExpressionNode): left: ExpressionNode @@ -400,6 +417,24 @@ class DivisionNode(BinaryOperatorNode): pass +@dataclass(frozen=True, eq=False) +class MaxNode(BinaryOperatorNode): + pass + + +@dataclass(frozen=True, eq=False) +class MinNode(BinaryOperatorNode): + pass + + +def maximum(left: "ExpressionNode", right: "ExpressionNode") -> "MaxNode": + return MaxNode(left, right) + + +def minimum(left: "ExpressionNode", right: "ExpressionNode") -> "MinNode": + return MinNode(left, right) + + @dataclass(frozen=True, eq=False) class TimeShiftNode(UnaryOperatorNode): time_shift: ExpressionNode diff --git a/src/gems/expression/indexing.py b/src/gems/expression/indexing.py index 022aa936..7d7f1f0b 100644 --- a/src/gems/expression/indexing.py +++ b/src/gems/expression/indexing.py @@ -19,12 +19,16 @@ from .expression import ( AdditionNode, AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, DivisionNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, MultiplicationNode, NegationNode, ParameterNode, @@ -159,6 +163,18 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStruct "Port fields aggregators must be resolved before computing indexing structure." ) + def floor(self, node: FloorNode) -> IndexingStructure: + return visit(node.operand, self) + + def ceil(self, node: CeilNode) -> IndexingStructure: + return visit(node.operand, self) + + def maximum(self, node: MaxNode) -> IndexingStructure: + return self._combine([node.left, node.right]) + + def minimum(self, node: MinNode) -> IndexingStructure: + return self._combine([node.left, node.right]) + def compute_indexation( expression: ExpressionNode, provider: IndexingStructureProvider diff --git a/src/gems/expression/parsing/antlr/Expr.interp b/src/gems/expression/parsing/antlr/Expr.interp index 94f1370f..4cd49648 100644 --- a/src/gems/expression/parsing/antlr/Expr.interp +++ b/src/gems/expression/parsing/antlr/Expr.interp @@ -51,4 +51,4 @@ right_expr atn: -[4, 1, 18, 140, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 2, 79, 8, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 90, 8, 2, 10, 2, 12, 2, 93, 9, 2, 1, 3, 1, 3, 3, 3, 97, 8, 3, 1, 4, 1, 4, 3, 4, 101, 8, 4, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 3, 5, 111, 8, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 5, 5, 119, 8, 5, 10, 5, 12, 5, 122, 9, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 130, 8, 6, 1, 6, 1, 6, 1, 6, 5, 6, 135, 8, 6, 10, 6, 12, 6, 138, 9, 6, 1, 6, 0, 3, 4, 10, 12, 7, 0, 2, 4, 6, 8, 10, 12, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 153, 0, 14, 1, 0, 0, 0, 2, 18, 1, 0, 0, 0, 4, 78, 1, 0, 0, 0, 6, 96, 1, 0, 0, 0, 8, 98, 1, 0, 0, 0, 10, 110, 1, 0, 0, 0, 12, 129, 1, 0, 0, 0, 14, 15, 5, 16, 0, 0, 15, 16, 5, 1, 0, 0, 16, 17, 5, 16, 0, 0, 17, 1, 1, 0, 0, 0, 18, 19, 3, 4, 2, 0, 19, 20, 5, 0, 0, 1, 20, 3, 1, 0, 0, 0, 21, 22, 6, 2, -1, 0, 22, 79, 3, 6, 3, 0, 23, 79, 3, 0, 0, 0, 24, 25, 5, 2, 0, 0, 25, 79, 3, 4, 2, 13, 26, 27, 5, 3, 0, 0, 27, 28, 3, 4, 2, 0, 28, 29, 5, 4, 0, 0, 29, 79, 1, 0, 0, 0, 30, 31, 5, 8, 0, 0, 31, 32, 5, 3, 0, 0, 32, 33, 3, 4, 2, 0, 33, 34, 5, 4, 0, 0, 34, 79, 1, 0, 0, 0, 35, 36, 5, 9, 0, 0, 36, 37, 5, 3, 0, 0, 37, 38, 3, 0, 0, 0, 38, 39, 5, 4, 0, 0, 39, 79, 1, 0, 0, 0, 40, 41, 5, 8, 0, 0, 41, 42, 5, 3, 0, 0, 42, 43, 3, 8, 4, 0, 43, 44, 5, 10, 0, 0, 44, 45, 3, 8, 4, 0, 45, 46, 5, 11, 0, 0, 46, 47, 3, 4, 2, 0, 47, 48, 5, 4, 0, 0, 48, 79, 1, 0, 0, 0, 49, 50, 5, 16, 0, 0, 50, 51, 5, 3, 0, 0, 51, 52, 3, 4, 2, 0, 52, 53, 5, 4, 0, 0, 53, 79, 1, 0, 0, 0, 54, 55, 5, 16, 0, 0, 55, 56, 5, 12, 0, 0, 56, 57, 3, 8, 4, 0, 57, 58, 5, 13, 0, 0, 58, 79, 1, 0, 0, 0, 59, 60, 5, 16, 0, 0, 60, 61, 5, 12, 0, 0, 61, 62, 3, 4, 2, 0, 62, 63, 5, 13, 0, 0, 63, 79, 1, 0, 0, 0, 64, 65, 5, 3, 0, 0, 65, 66, 3, 4, 2, 0, 66, 67, 5, 4, 0, 0, 67, 68, 5, 12, 0, 0, 68, 69, 3, 8, 4, 0, 69, 70, 5, 13, 0, 0, 70, 79, 1, 0, 0, 0, 71, 72, 5, 3, 0, 0, 72, 73, 3, 4, 2, 0, 73, 74, 5, 4, 0, 0, 74, 75, 5, 12, 0, 0, 75, 76, 3, 4, 2, 0, 76, 77, 5, 13, 0, 0, 77, 79, 1, 0, 0, 0, 78, 21, 1, 0, 0, 0, 78, 23, 1, 0, 0, 0, 78, 24, 1, 0, 0, 0, 78, 26, 1, 0, 0, 0, 78, 30, 1, 0, 0, 0, 78, 35, 1, 0, 0, 0, 78, 40, 1, 0, 0, 0, 78, 49, 1, 0, 0, 0, 78, 54, 1, 0, 0, 0, 78, 59, 1, 0, 0, 0, 78, 64, 1, 0, 0, 0, 78, 71, 1, 0, 0, 0, 79, 91, 1, 0, 0, 0, 80, 81, 10, 11, 0, 0, 81, 82, 7, 0, 0, 0, 82, 90, 3, 4, 2, 12, 83, 84, 10, 10, 0, 0, 84, 85, 7, 1, 0, 0, 85, 90, 3, 4, 2, 11, 86, 87, 10, 9, 0, 0, 87, 88, 5, 17, 0, 0, 88, 90, 3, 4, 2, 10, 89, 80, 1, 0, 0, 0, 89, 83, 1, 0, 0, 0, 89, 86, 1, 0, 0, 0, 90, 93, 1, 0, 0, 0, 91, 89, 1, 0, 0, 0, 91, 92, 1, 0, 0, 0, 92, 5, 1, 0, 0, 0, 93, 91, 1, 0, 0, 0, 94, 97, 5, 14, 0, 0, 95, 97, 5, 16, 0, 0, 96, 94, 1, 0, 0, 0, 96, 95, 1, 0, 0, 0, 97, 7, 1, 0, 0, 0, 98, 100, 5, 15, 0, 0, 99, 101, 3, 10, 5, 0, 100, 99, 1, 0, 0, 0, 100, 101, 1, 0, 0, 0, 101, 9, 1, 0, 0, 0, 102, 103, 6, 5, -1, 0, 103, 104, 7, 1, 0, 0, 104, 111, 3, 6, 3, 0, 105, 106, 7, 1, 0, 0, 106, 107, 5, 3, 0, 0, 107, 108, 3, 4, 2, 0, 108, 109, 5, 4, 0, 0, 109, 111, 1, 0, 0, 0, 110, 102, 1, 0, 0, 0, 110, 105, 1, 0, 0, 0, 111, 120, 1, 0, 0, 0, 112, 113, 10, 4, 0, 0, 113, 114, 7, 0, 0, 0, 114, 119, 3, 12, 6, 0, 115, 116, 10, 3, 0, 0, 116, 117, 7, 1, 0, 0, 117, 119, 3, 12, 6, 0, 118, 112, 1, 0, 0, 0, 118, 115, 1, 0, 0, 0, 119, 122, 1, 0, 0, 0, 120, 118, 1, 0, 0, 0, 120, 121, 1, 0, 0, 0, 121, 11, 1, 0, 0, 0, 122, 120, 1, 0, 0, 0, 123, 124, 6, 6, -1, 0, 124, 125, 5, 3, 0, 0, 125, 126, 3, 4, 2, 0, 126, 127, 5, 4, 0, 0, 127, 130, 1, 0, 0, 0, 128, 130, 3, 6, 3, 0, 129, 123, 1, 0, 0, 0, 129, 128, 1, 0, 0, 0, 130, 136, 1, 0, 0, 0, 131, 132, 10, 3, 0, 0, 132, 133, 7, 0, 0, 0, 133, 135, 3, 12, 6, 4, 134, 131, 1, 0, 0, 0, 135, 138, 1, 0, 0, 0, 136, 134, 1, 0, 0, 0, 136, 137, 1, 0, 0, 0, 137, 13, 1, 0, 0, 0, 138, 136, 1, 0, 0, 0, 10, 78, 89, 91, 96, 100, 110, 118, 120, 129, 136] \ No newline at end of file +[4, 1, 18, 147, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 2, 86, 8, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 97, 8, 2, 10, 2, 12, 2, 100, 9, 2, 1, 3, 1, 3, 3, 3, 104, 8, 3, 1, 4, 1, 4, 3, 4, 108, 8, 4, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 3, 5, 118, 8, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 5, 5, 126, 8, 5, 10, 5, 12, 5, 129, 9, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 137, 8, 6, 1, 6, 1, 6, 1, 6, 5, 6, 142, 8, 6, 10, 6, 12, 6, 145, 9, 6, 1, 6, 0, 3, 4, 10, 12, 7, 0, 2, 4, 6, 8, 10, 12, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 161, 0, 14, 1, 0, 0, 0, 2, 18, 1, 0, 0, 0, 4, 85, 1, 0, 0, 0, 6, 103, 1, 0, 0, 0, 8, 105, 1, 0, 0, 0, 10, 117, 1, 0, 0, 0, 12, 136, 1, 0, 0, 0, 14, 15, 5, 16, 0, 0, 15, 16, 5, 1, 0, 0, 16, 17, 5, 16, 0, 0, 17, 1, 1, 0, 0, 0, 18, 19, 3, 4, 2, 0, 19, 20, 5, 0, 0, 1, 20, 3, 1, 0, 0, 0, 21, 22, 6, 2, -1, 0, 22, 86, 3, 6, 3, 0, 23, 86, 3, 0, 0, 0, 24, 25, 5, 2, 0, 0, 25, 86, 3, 4, 2, 14, 26, 27, 5, 3, 0, 0, 27, 28, 3, 4, 2, 0, 28, 29, 5, 4, 0, 0, 29, 86, 1, 0, 0, 0, 30, 31, 5, 8, 0, 0, 31, 32, 5, 3, 0, 0, 32, 33, 3, 4, 2, 0, 33, 34, 5, 4, 0, 0, 34, 86, 1, 0, 0, 0, 35, 36, 5, 9, 0, 0, 36, 37, 5, 3, 0, 0, 37, 38, 3, 0, 0, 0, 38, 39, 5, 4, 0, 0, 39, 86, 1, 0, 0, 0, 40, 41, 5, 8, 0, 0, 41, 42, 5, 3, 0, 0, 42, 43, 3, 8, 4, 0, 43, 44, 5, 10, 0, 0, 44, 45, 3, 8, 4, 0, 45, 46, 5, 11, 0, 0, 46, 47, 3, 4, 2, 0, 47, 48, 5, 4, 0, 0, 48, 86, 1, 0, 0, 0, 49, 50, 5, 16, 0, 0, 50, 51, 5, 3, 0, 0, 51, 52, 3, 4, 2, 0, 52, 53, 5, 4, 0, 0, 53, 86, 1, 0, 0, 0, 54, 55, 5, 16, 0, 0, 55, 56, 5, 3, 0, 0, 56, 57, 3, 4, 2, 0, 57, 58, 5, 11, 0, 0, 58, 59, 3, 4, 2, 0, 59, 60, 5, 4, 0, 0, 60, 86, 1, 0, 0, 0, 61, 62, 5, 16, 0, 0, 62, 63, 5, 12, 0, 0, 63, 64, 3, 8, 4, 0, 64, 65, 5, 13, 0, 0, 65, 86, 1, 0, 0, 0, 66, 67, 5, 16, 0, 0, 67, 68, 5, 12, 0, 0, 68, 69, 3, 4, 2, 0, 69, 70, 5, 13, 0, 0, 70, 86, 1, 0, 0, 0, 71, 72, 5, 3, 0, 0, 72, 73, 3, 4, 2, 0, 73, 74, 5, 4, 0, 0, 74, 75, 5, 12, 0, 0, 75, 76, 3, 8, 4, 0, 76, 77, 5, 13, 0, 0, 77, 86, 1, 0, 0, 0, 78, 79, 5, 3, 0, 0, 79, 80, 3, 4, 2, 0, 80, 81, 5, 4, 0, 0, 81, 82, 5, 12, 0, 0, 82, 83, 3, 4, 2, 0, 83, 84, 5, 13, 0, 0, 84, 86, 1, 0, 0, 0, 85, 21, 1, 0, 0, 0, 85, 23, 1, 0, 0, 0, 85, 24, 1, 0, 0, 0, 85, 26, 1, 0, 0, 0, 85, 30, 1, 0, 0, 0, 85, 35, 1, 0, 0, 0, 85, 40, 1, 0, 0, 0, 85, 49, 1, 0, 0, 0, 85, 54, 1, 0, 0, 0, 85, 61, 1, 0, 0, 0, 85, 66, 1, 0, 0, 0, 85, 71, 1, 0, 0, 0, 85, 78, 1, 0, 0, 0, 86, 98, 1, 0, 0, 0, 87, 88, 10, 12, 0, 0, 88, 89, 7, 0, 0, 0, 89, 97, 3, 4, 2, 13, 90, 91, 10, 11, 0, 0, 91, 92, 7, 1, 0, 0, 92, 97, 3, 4, 2, 12, 93, 94, 10, 10, 0, 0, 94, 95, 5, 17, 0, 0, 95, 97, 3, 4, 2, 11, 96, 87, 1, 0, 0, 0, 96, 90, 1, 0, 0, 0, 96, 93, 1, 0, 0, 0, 97, 100, 1, 0, 0, 0, 98, 96, 1, 0, 0, 0, 98, 99, 1, 0, 0, 0, 99, 5, 1, 0, 0, 0, 100, 98, 1, 0, 0, 0, 101, 104, 5, 14, 0, 0, 102, 104, 5, 16, 0, 0, 103, 101, 1, 0, 0, 0, 103, 102, 1, 0, 0, 0, 104, 7, 1, 0, 0, 0, 105, 107, 5, 15, 0, 0, 106, 108, 3, 10, 5, 0, 107, 106, 1, 0, 0, 0, 107, 108, 1, 0, 0, 0, 108, 9, 1, 0, 0, 0, 109, 110, 6, 5, -1, 0, 110, 111, 7, 1, 0, 0, 111, 118, 3, 6, 3, 0, 112, 113, 7, 1, 0, 0, 113, 114, 5, 3, 0, 0, 114, 115, 3, 4, 2, 0, 115, 116, 5, 4, 0, 0, 116, 118, 1, 0, 0, 0, 117, 109, 1, 0, 0, 0, 117, 112, 1, 0, 0, 0, 118, 127, 1, 0, 0, 0, 119, 120, 10, 4, 0, 0, 120, 121, 7, 0, 0, 0, 121, 126, 3, 12, 6, 0, 122, 123, 10, 3, 0, 0, 123, 124, 7, 1, 0, 0, 124, 126, 3, 12, 6, 0, 125, 119, 1, 0, 0, 0, 125, 122, 1, 0, 0, 0, 126, 129, 1, 0, 0, 0, 127, 125, 1, 0, 0, 0, 127, 128, 1, 0, 0, 0, 128, 11, 1, 0, 0, 0, 129, 127, 1, 0, 0, 0, 130, 131, 6, 6, -1, 0, 131, 132, 5, 3, 0, 0, 132, 133, 3, 4, 2, 0, 133, 134, 5, 4, 0, 0, 134, 137, 1, 0, 0, 0, 135, 137, 3, 6, 3, 0, 136, 130, 1, 0, 0, 0, 136, 135, 1, 0, 0, 0, 137, 143, 1, 0, 0, 0, 138, 139, 10, 3, 0, 0, 139, 140, 7, 0, 0, 0, 140, 142, 3, 12, 6, 4, 141, 138, 1, 0, 0, 0, 142, 145, 1, 0, 0, 0, 143, 141, 1, 0, 0, 0, 143, 144, 1, 0, 0, 0, 144, 13, 1, 0, 0, 0, 145, 143, 1, 0, 0, 0, 10, 85, 96, 98, 103, 107, 117, 125, 127, 136, 143] \ No newline at end of file diff --git a/src/gems/expression/parsing/antlr/ExprLexer.py b/src/gems/expression/parsing/antlr/ExprLexer.py index 60c2d135..261c7170 100644 --- a/src/gems/expression/parsing/antlr/ExprLexer.py +++ b/src/gems/expression/parsing/antlr/ExprLexer.py @@ -1,8 +1,7 @@ -# Generated from Expr.g4 by ANTLR 4.13.2 -import sys -from io import StringIO - +# Generated from /home/user/GemsPy/grammar/Expr.g4 by ANTLR 4.13.2 from antlr4 import * +from io import StringIO +import sys if sys.version_info[1] > 5: from typing import TextIO diff --git a/src/gems/expression/parsing/antlr/ExprParser.py b/src/gems/expression/parsing/antlr/ExprParser.py index 859d23dc..8d66b7c6 100644 --- a/src/gems/expression/parsing/antlr/ExprParser.py +++ b/src/gems/expression/parsing/antlr/ExprParser.py @@ -1,9 +1,8 @@ -# Generated from Expr.g4 by ANTLR 4.13.2 +# Generated from /home/user/GemsPy/grammar/Expr.g4 by ANTLR 4.13.2 # encoding: utf-8 -import sys -from io import StringIO - from antlr4 import * +from io import StringIO +import sys if sys.version_info[1] > 5: from typing import TextIO @@ -16,7 +15,7 @@ def serializedATN(): 4, 1, 18, - 140, + 147, 2, 0, 7, @@ -173,9 +172,23 @@ def serializedATN(): 2, 1, 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, 3, 2, - 79, + 86, 8, 2, 1, @@ -198,14 +211,14 @@ def serializedATN(): 2, 5, 2, - 90, + 97, 8, 2, 10, 2, 12, 2, - 93, + 100, 9, 2, 1, @@ -214,7 +227,7 @@ def serializedATN(): 3, 3, 3, - 97, + 104, 8, 3, 1, @@ -223,7 +236,7 @@ def serializedATN(): 4, 3, 4, - 101, + 108, 8, 4, 1, @@ -244,7 +257,7 @@ def serializedATN(): 5, 3, 5, - 111, + 118, 8, 5, 1, @@ -261,14 +274,14 @@ def serializedATN(): 5, 5, 5, - 119, + 126, 8, 5, 10, 5, 12, 5, - 122, + 129, 9, 5, 1, @@ -285,7 +298,7 @@ def serializedATN(): 6, 3, 6, - 130, + 137, 8, 6, 1, @@ -296,14 +309,14 @@ def serializedATN(): 6, 5, 6, - 135, + 142, 8, 6, 10, 6, 12, 6, - 138, + 145, 9, 6, 1, @@ -333,7 +346,7 @@ def serializedATN(): 2, 7, 7, - 153, + 161, 0, 14, 1, @@ -347,31 +360,31 @@ def serializedATN(): 0, 0, 4, - 78, + 85, 1, 0, 0, 0, 6, - 96, + 103, 1, 0, 0, 0, 8, - 98, + 105, 1, 0, 0, 0, 10, - 110, + 117, 1, 0, 0, 0, 12, - 129, + 136, 1, 0, 0, @@ -425,13 +438,13 @@ def serializedATN(): -1, 0, 22, - 79, + 86, 3, 6, 3, 0, 23, - 79, + 86, 3, 0, 0, @@ -443,11 +456,11 @@ def serializedATN(): 0, 0, 25, - 79, + 86, 3, 4, 2, - 13, + 14, 26, 27, 5, @@ -467,7 +480,7 @@ def serializedATN(): 0, 0, 29, - 79, + 86, 1, 0, 0, @@ -497,7 +510,7 @@ def serializedATN(): 0, 0, 34, - 79, + 86, 1, 0, 0, @@ -527,7 +540,7 @@ def serializedATN(): 0, 0, 39, - 79, + 86, 1, 0, 0, @@ -581,7 +594,7 @@ def serializedATN(): 0, 0, 48, - 79, + 86, 1, 0, 0, @@ -611,7 +624,7 @@ def serializedATN(): 0, 0, 53, - 79, + 86, 1, 0, 0, @@ -625,73 +638,73 @@ def serializedATN(): 55, 56, 5, - 12, + 3, 0, 0, 56, 57, 3, - 8, 4, + 2, 0, 57, 58, 5, - 13, + 11, 0, 0, 58, - 79, - 1, - 0, - 0, + 59, + 3, + 4, + 2, 0, 59, 60, 5, - 16, + 4, 0, 0, 60, - 61, - 5, - 12, + 86, + 1, + 0, 0, 0, 61, 62, - 3, - 4, - 2, + 5, + 16, + 0, 0, 62, 63, 5, - 13, + 12, 0, 0, 63, - 79, - 1, - 0, - 0, + 64, + 3, + 8, + 4, 0, 64, 65, 5, - 3, + 13, 0, 0, 65, - 66, - 3, - 4, - 2, + 86, + 1, + 0, + 0, 0, 66, 67, 5, - 4, + 16, 0, 0, 67, @@ -703,8 +716,8 @@ def serializedATN(): 68, 69, 3, - 8, 4, + 2, 0, 69, 70, @@ -713,7 +726,7 @@ def serializedATN(): 0, 0, 70, - 79, + 86, 1, 0, 0, @@ -745,8 +758,8 @@ def serializedATN(): 75, 76, 3, + 8, 4, - 2, 0, 76, 77, @@ -755,514 +768,562 @@ def serializedATN(): 0, 0, 77, - 79, + 86, 1, 0, 0, 0, 78, + 79, + 5, + 3, + 0, + 0, + 79, + 80, + 3, + 4, + 2, + 0, + 80, + 81, + 5, + 4, + 0, + 0, + 81, + 82, + 5, + 12, + 0, + 0, + 82, + 83, + 3, + 4, + 2, + 0, + 83, + 84, + 5, + 13, + 0, + 0, + 84, + 86, + 1, + 0, + 0, + 0, + 85, 21, 1, 0, 0, 0, - 78, + 85, 23, 1, 0, 0, 0, - 78, + 85, 24, 1, 0, 0, 0, - 78, + 85, 26, 1, 0, 0, 0, - 78, + 85, 30, 1, 0, 0, 0, - 78, + 85, 35, 1, 0, 0, 0, - 78, + 85, 40, 1, 0, 0, 0, - 78, + 85, 49, 1, 0, 0, 0, - 78, + 85, 54, 1, 0, 0, 0, - 78, - 59, + 85, + 61, 1, 0, 0, 0, - 78, - 64, + 85, + 66, 1, 0, 0, 0, - 78, + 85, 71, 1, 0, 0, 0, - 79, - 91, + 85, + 78, 1, 0, 0, 0, - 80, - 81, + 86, + 98, + 1, + 0, + 0, + 0, + 87, + 88, 10, - 11, + 12, 0, 0, - 81, - 82, + 88, + 89, 7, 0, 0, 0, - 82, - 90, + 89, + 97, 3, 4, 2, - 12, - 83, - 84, - 10, + 13, + 90, + 91, 10, + 11, 0, 0, - 84, - 85, + 91, + 92, 7, 1, 0, 0, - 85, - 90, + 92, + 97, 3, 4, 2, - 11, - 86, - 87, + 12, + 93, + 94, + 10, 10, - 9, 0, 0, - 87, - 88, + 94, + 95, 5, 17, 0, 0, - 88, - 90, + 95, + 97, 3, 4, 2, - 10, - 89, - 80, + 11, + 96, + 87, 1, 0, 0, 0, - 89, - 83, + 96, + 90, 1, 0, 0, 0, - 89, - 86, + 96, + 93, 1, 0, 0, 0, - 90, - 93, + 97, + 100, 1, 0, 0, 0, - 91, - 89, + 98, + 96, 1, 0, 0, 0, - 91, - 92, + 98, + 99, 1, 0, 0, 0, - 92, + 99, 5, 1, 0, 0, 0, - 93, - 91, + 100, + 98, 1, 0, 0, 0, - 94, - 97, + 101, + 104, 5, 14, 0, 0, - 95, - 97, + 102, + 104, 5, 16, 0, 0, - 96, - 94, + 103, + 101, 1, 0, 0, 0, - 96, - 95, + 103, + 102, 1, 0, 0, 0, - 97, + 104, 7, 1, 0, 0, 0, - 98, - 100, + 105, + 107, 5, 15, 0, 0, - 99, - 101, + 106, + 108, 3, 10, 5, 0, - 100, - 99, + 107, + 106, 1, 0, 0, 0, - 100, - 101, + 107, + 108, 1, 0, 0, 0, - 101, + 108, 9, 1, 0, 0, 0, - 102, - 103, + 109, + 110, 6, 5, -1, 0, - 103, - 104, + 110, + 111, 7, 1, 0, 0, - 104, 111, + 118, 3, 6, 3, 0, - 105, - 106, + 112, + 113, 7, 1, 0, 0, - 106, - 107, + 113, + 114, 5, 3, 0, 0, - 107, - 108, + 114, + 115, 3, 4, 2, 0, - 108, - 109, + 115, + 116, 5, 4, 0, 0, - 109, - 111, + 116, + 118, 1, 0, 0, 0, - 110, - 102, + 117, + 109, 1, 0, 0, 0, - 110, - 105, + 117, + 112, 1, 0, 0, 0, - 111, - 120, + 118, + 127, 1, 0, 0, 0, - 112, - 113, + 119, + 120, 10, 4, 0, 0, - 113, - 114, + 120, + 121, 7, 0, 0, 0, - 114, - 119, + 121, + 126, 3, 12, 6, 0, - 115, - 116, + 122, + 123, 10, 3, 0, 0, - 116, - 117, + 123, + 124, 7, 1, 0, 0, - 117, - 119, + 124, + 126, 3, 12, 6, 0, - 118, - 112, + 125, + 119, 1, 0, 0, 0, - 118, - 115, + 125, + 122, 1, 0, 0, 0, - 119, - 122, + 126, + 129, 1, 0, 0, 0, - 120, - 118, + 127, + 125, 1, 0, 0, 0, - 120, - 121, + 127, + 128, 1, 0, 0, 0, - 121, + 128, 11, 1, 0, 0, 0, - 122, - 120, + 129, + 127, 1, 0, 0, 0, - 123, - 124, + 130, + 131, 6, 6, -1, 0, - 124, - 125, + 131, + 132, 5, 3, 0, 0, - 125, - 126, + 132, + 133, 3, 4, 2, 0, - 126, - 127, + 133, + 134, 5, 4, 0, 0, - 127, - 130, + 134, + 137, 1, 0, 0, 0, - 128, - 130, + 135, + 137, 3, 6, 3, 0, - 129, - 123, + 136, + 130, 1, 0, 0, 0, - 129, - 128, + 136, + 135, 1, 0, 0, 0, - 130, - 136, + 137, + 143, 1, 0, 0, 0, - 131, - 132, + 138, + 139, 10, 3, 0, 0, - 132, - 133, + 139, + 140, 7, 0, 0, 0, - 133, - 135, + 140, + 142, 3, 12, 6, 4, - 134, - 131, + 141, + 138, 1, 0, 0, 0, - 135, - 138, + 142, + 145, 1, 0, 0, 0, - 136, - 134, + 143, + 141, 1, 0, 0, 0, - 136, - 137, + 143, + 144, 1, 0, 0, 0, - 137, + 144, 13, 1, 0, 0, 0, - 138, - 136, + 145, + 143, 1, 0, 0, 0, 10, - 78, - 89, - 91, + 85, 96, - 100, - 110, - 118, - 120, - 129, + 98, + 103, + 107, + 117, + 125, + 127, 136, + 143, ] @@ -1477,6 +1538,28 @@ def accept(self, visitor: ParseTreeVisitor): else: return visitor.visitChildren(self) + class BinaryFunctionContext(ExprContext): + def __init__( + self, parser, ctx: ParserRuleContext + ): # actually a ExprParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def IDENTIFIER(self): + return self.getToken(ExprParser.IDENTIFIER, 0) + + def expr(self, i: int = None): + if i is None: + return self.getTypedRuleContexts(ExprParser.ExprContext) + else: + return self.getTypedRuleContext(ExprParser.ExprContext, i) + + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, "visitBinaryFunction"): + return visitor.visitBinaryFunction(self) + else: + return visitor.visitChildren(self) + class NegationContext(ExprContext): def __init__( self, parser, ctx: ParserRuleContext @@ -1748,7 +1831,7 @@ def expr(self, _p: int = 0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 78 + self.state = 85 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input, 0, self._ctx) if la_ == 1: @@ -1775,7 +1858,7 @@ def expr(self, _p: int = 0): self.state = 24 self.match(ExprParser.T__1) self.state = 25 - self.expr(13) + self.expr(14) pass elif la_ == 4: @@ -1855,53 +1938,53 @@ def expr(self, _p: int = 0): pass elif la_ == 9: - localctx = ExprParser.TimeShiftContext(self, localctx) + localctx = ExprParser.BinaryFunctionContext(self, localctx) self._ctx = localctx _prevctx = localctx self.state = 54 self.match(ExprParser.IDENTIFIER) self.state = 55 - self.match(ExprParser.T__11) + self.match(ExprParser.T__2) self.state = 56 - self.shift() + self.expr(0) self.state = 57 - self.match(ExprParser.T__12) + self.match(ExprParser.T__10) + self.state = 58 + self.expr(0) + self.state = 59 + self.match(ExprParser.T__3) pass elif la_ == 10: - localctx = ExprParser.TimeIndexContext(self, localctx) + localctx = ExprParser.TimeShiftContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 59 - self.match(ExprParser.IDENTIFIER) - self.state = 60 - self.match(ExprParser.T__11) self.state = 61 - self.expr(0) + self.match(ExprParser.IDENTIFIER) self.state = 62 + self.match(ExprParser.T__11) + self.state = 63 + self.shift() + self.state = 64 self.match(ExprParser.T__12) pass elif la_ == 11: - localctx = ExprParser.TimeShiftExprContext(self, localctx) + localctx = ExprParser.TimeIndexContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 64 - self.match(ExprParser.T__2) - self.state = 65 - self.expr(0) self.state = 66 - self.match(ExprParser.T__3) + self.match(ExprParser.IDENTIFIER) self.state = 67 self.match(ExprParser.T__11) self.state = 68 - self.shift() + self.expr(0) self.state = 69 self.match(ExprParser.T__12) pass elif la_ == 12: - localctx = ExprParser.TimeIndexExprContext(self, localctx) + localctx = ExprParser.TimeShiftExprContext(self, localctx) self._ctx = localctx _prevctx = localctx self.state = 71 @@ -1913,13 +1996,31 @@ def expr(self, _p: int = 0): self.state = 74 self.match(ExprParser.T__11) self.state = 75 - self.expr(0) + self.shift() self.state = 76 self.match(ExprParser.T__12) pass + elif la_ == 13: + localctx = ExprParser.TimeIndexExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 78 + self.match(ExprParser.T__2) + self.state = 79 + self.expr(0) + self.state = 80 + self.match(ExprParser.T__3) + self.state = 81 + self.match(ExprParser.T__11) + self.state = 82 + self.expr(0) + self.state = 83 + self.match(ExprParser.T__12) + pass + self._ctx.stop = self._input.LT(-1) - self.state = 91 + self.state = 98 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input, 2, self._ctx) while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: @@ -1927,7 +2028,7 @@ def expr(self, _p: int = 0): if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 89 + self.state = 96 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input, 1, self._ctx) if la_ == 1: @@ -1937,14 +2038,14 @@ def expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_expr ) - self.state = 80 - if not self.precpred(self._ctx, 11): + self.state = 87 + if not self.precpred(self._ctx, 12): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( - self, "self.precpred(self._ctx, 11)" + self, "self.precpred(self._ctx, 12)" ) - self.state = 81 + self.state = 88 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 5 or _la == 6): @@ -1952,8 +2053,8 @@ def expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 82 - self.expr(12) + self.state = 89 + self.expr(13) pass elif la_ == 2: @@ -1963,14 +2064,14 @@ def expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_expr ) - self.state = 83 - if not self.precpred(self._ctx, 10): + self.state = 90 + if not self.precpred(self._ctx, 11): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( - self, "self.precpred(self._ctx, 10)" + self, "self.precpred(self._ctx, 11)" ) - self.state = 84 + self.state = 91 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -1978,8 +2079,8 @@ def expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 85 - self.expr(11) + self.state = 92 + self.expr(12) pass elif la_ == 3: @@ -1989,20 +2090,20 @@ def expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_expr ) - self.state = 86 - if not self.precpred(self._ctx, 9): + self.state = 93 + if not self.precpred(self._ctx, 10): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( - self, "self.precpred(self._ctx, 9)" + self, "self.precpred(self._ctx, 10)" ) - self.state = 87 + self.state = 94 self.match(ExprParser.COMPARISON) - self.state = 88 - self.expr(10) + self.state = 95 + self.expr(11) pass - self.state = 93 + self.state = 100 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input, 2, self._ctx) @@ -2065,19 +2166,19 @@ def atom(self): localctx = ExprParser.AtomContext(self, self._ctx, self.state) self.enterRule(localctx, 6, self.RULE_atom) try: - self.state = 96 + self.state = 103 self._errHandler.sync(self) token = self._input.LA(1) if token in [14]: localctx = ExprParser.NumberContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 94 + self.state = 101 self.match(ExprParser.NUMBER) pass elif token in [16]: localctx = ExprParser.IdentifierContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 95 + self.state = 102 self.match(ExprParser.IDENTIFIER) pass else: @@ -2121,13 +2222,13 @@ def shift(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 98 + self.state = 105 self.match(ExprParser.TIME) - self.state = 100 + self.state = 107 self._errHandler.sync(self) _la = self._input.LA(1) if _la == 2 or _la == 7: - self.state = 99 + self.state = 106 self.shift_expr(0) except RecognitionException as re: @@ -2237,7 +2338,7 @@ def shift_expr(self, _p: int = 0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 110 + self.state = 117 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input, 5, self._ctx) if la_ == 1: @@ -2245,7 +2346,7 @@ def shift_expr(self, _p: int = 0): self._ctx = localctx _prevctx = localctx - self.state = 103 + self.state = 110 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -2253,7 +2354,7 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 104 + self.state = 111 self.atom() pass @@ -2261,7 +2362,7 @@ def shift_expr(self, _p: int = 0): localctx = ExprParser.SignedExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 105 + self.state = 112 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -2269,16 +2370,16 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 106 + self.state = 113 self.match(ExprParser.T__2) - self.state = 107 + self.state = 114 self.expr(0) - self.state = 108 + self.state = 115 self.match(ExprParser.T__3) pass self._ctx.stop = self._input.LT(-1) - self.state = 120 + self.state = 127 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input, 7, self._ctx) while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: @@ -2286,7 +2387,7 @@ def shift_expr(self, _p: int = 0): if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 118 + self.state = 125 self._errHandler.sync(self) la_ = self._interp.adaptivePredict(self._input, 6, self._ctx) if la_ == 1: @@ -2299,14 +2400,14 @@ def shift_expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_shift_expr ) - self.state = 112 + self.state = 119 if not self.precpred(self._ctx, 4): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 4)" ) - self.state = 113 + self.state = 120 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 5 or _la == 6): @@ -2314,7 +2415,7 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 114 + self.state = 121 self.right_expr(0) pass @@ -2328,14 +2429,14 @@ def shift_expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_shift_expr ) - self.state = 115 + self.state = 122 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 3)" ) - self.state = 116 + self.state = 123 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -2343,11 +2444,11 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 117 + self.state = 124 self.right_expr(0) pass - self.state = 122 + self.state = 129 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input, 7, self._ctx) @@ -2436,7 +2537,7 @@ def right_expr(self, _p: int = 0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 129 + self.state = 136 self._errHandler.sync(self) token = self._input.LA(1) if token in [3]: @@ -2444,25 +2545,25 @@ def right_expr(self, _p: int = 0): self._ctx = localctx _prevctx = localctx - self.state = 124 + self.state = 131 self.match(ExprParser.T__2) - self.state = 125 + self.state = 132 self.expr(0) - self.state = 126 + self.state = 133 self.match(ExprParser.T__3) pass elif token in [14, 16]: localctx = ExprParser.RightAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 128 + self.state = 135 self.atom() pass else: raise NoViableAltException(self) self._ctx.stop = self._input.LT(-1) - self.state = 136 + self.state = 143 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: @@ -2477,14 +2578,14 @@ def right_expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_right_expr ) - self.state = 131 + self.state = 138 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 3)" ) - self.state = 132 + self.state = 139 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 5 or _la == 6): @@ -2492,9 +2593,9 @@ def right_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 133 + self.state = 140 self.right_expr(4) - self.state = 138 + self.state = 145 self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) @@ -2520,13 +2621,13 @@ def sempred(self, localctx: RuleContext, ruleIndex: int, predIndex: int): def expr_sempred(self, localctx: ExprContext, predIndex: int): if predIndex == 0: - return self.precpred(self._ctx, 11) + return self.precpred(self._ctx, 12) if predIndex == 1: - return self.precpred(self._ctx, 10) + return self.precpred(self._ctx, 11) if predIndex == 2: - return self.precpred(self._ctx, 9) + return self.precpred(self._ctx, 10) def shift_expr_sempred(self, localctx: Shift_exprContext, predIndex: int): if predIndex == 3: diff --git a/src/gems/expression/parsing/antlr/ExprVisitor.py b/src/gems/expression/parsing/antlr/ExprVisitor.py index a98fbbaf..68a41ab8 100644 --- a/src/gems/expression/parsing/antlr/ExprVisitor.py +++ b/src/gems/expression/parsing/antlr/ExprVisitor.py @@ -1,4 +1,4 @@ -# Generated from Expr.g4 by ANTLR 4.13.2 +# Generated from /home/user/GemsPy/grammar/Expr.g4 by ANTLR 4.13.2 from antlr4 import * if "." in __name__: @@ -22,6 +22,10 @@ def visitFullexpr(self, ctx: ExprParser.FullexprContext): def visitPortFieldSum(self, ctx: ExprParser.PortFieldSumContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#binaryFunction. + def visitBinaryFunction(self, ctx: ExprParser.BinaryFunctionContext): + return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#negation. def visitNegation(self, ctx: ExprParser.NegationContext): return self.visitChildren(ctx) diff --git a/src/gems/expression/parsing/parse_expression.py b/src/gems/expression/parsing/parse_expression.py index a0bc80e6..6d4243f4 100644 --- a/src/gems/expression/parsing/parse_expression.py +++ b/src/gems/expression/parsing/parse_expression.py @@ -22,6 +22,8 @@ ComparisonNode, PortFieldAggregatorNode, PortFieldNode, + maximum, + minimum, ) from gems.expression.parsing.antlr.ExprLexer import ExprLexer from gems.expression.parsing.antlr.ExprParser import ExprParser @@ -166,6 +168,20 @@ def visitFunction(self, ctx: ExprParser.FunctionContext) -> ExpressionNode: raise ValueError(f"Encountered invalid function name {function_name}") return fn(operand) + # Visit a parse tree produced by ExprParser#binaryFunction. + def visitBinaryFunction( + self, ctx: ExprParser.BinaryFunctionContext + ) -> ExpressionNode: + function_name: str = ctx.IDENTIFIER().getText() # type: ignore + left: ExpressionNode = ctx.expr(0).accept(self) # type: ignore + right: ExpressionNode = ctx.expr(1).accept(self) # type: ignore + fn = _BINARY_FUNCTIONS.get(function_name, None) + if fn is None: + raise ValueError( + f"Encountered invalid binary function name {function_name}" + ) + return fn(left, right) + # Visit a parse tree produced by ExprParser#shift. def visitShift(self, ctx: ExprParser.ShiftContext) -> ExpressionNode: if ctx.shift_expr() is None: # type: ignore @@ -235,6 +251,13 @@ def visitRightAtom(self, ctx: ExprParser.RightAtomContext) -> ExpressionNode: _FUNCTIONS = { "expec": ExpressionNode.expec, + "floor": ExpressionNode.floor, + "ceil": ExpressionNode.ceil, +} + +_BINARY_FUNCTIONS = { + "max": maximum, + "min": minimum, } diff --git a/src/gems/expression/print.py b/src/gems/expression/print.py index 50308d87..36583f86 100644 --- a/src/gems/expression/print.py +++ b/src/gems/expression/print.py @@ -15,9 +15,13 @@ from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, ExpressionNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -130,6 +134,18 @@ def port_field(self, node: PortFieldNode) -> str: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> str: return f"({visit(node.operand, self)}.{node.aggregator})" + def floor(self, node: FloorNode) -> str: + return f"floor({visit(node.operand, self)})" + + def ceil(self, node: CeilNode) -> str: + return f"ceil({visit(node.operand, self)})" + + def maximum(self, node: MaxNode) -> str: + return f"max({visit(node.left, self)}, {visit(node.right, self)})" + + def minimum(self, node: MinNode) -> str: + return f"min({visit(node.left, self)}, {visit(node.right, self)})" + def print_expr(expression: ExpressionNode) -> str: return visit(expression, PrinterVisitor()) diff --git a/src/gems/expression/visitor.py b/src/gems/expression/visitor.py index 7180b6ab..dc5ad42c 100644 --- a/src/gems/expression/visitor.py +++ b/src/gems/expression/visitor.py @@ -13,6 +13,7 @@ """ Defines abstract base class for visitors of expressions. """ + import typing from abc import ABC, abstractmethod from typing import Generic, Protocol, TypeVar @@ -20,12 +21,16 @@ from gems.expression.expression import ( AdditionNode, AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, DivisionNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, MultiplicationNode, NegationNode, ParameterNode, @@ -128,6 +133,22 @@ def port_field(self, node: PortFieldNode) -> T: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: ... + @abstractmethod + def floor(self, node: FloorNode) -> T: + ... + + @abstractmethod + def ceil(self, node: CeilNode) -> T: + ... + + @abstractmethod + def maximum(self, node: MaxNode) -> T: + ... + + @abstractmethod + def minimum(self, node: MinNode) -> T: + ... + def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: """ @@ -171,6 +192,14 @@ def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: return visitor.port_field(root) elif isinstance(root, PortFieldAggregatorNode): return visitor.port_field_aggregator(root) + elif isinstance(root, FloorNode): + return visitor.floor(root) + elif isinstance(root, CeilNode): + return visitor.ceil(root) + elif isinstance(root, MaxNode): + return visitor.maximum(root) + elif isinstance(root, MinNode): + return visitor.minimum(root) raise ValueError(f"Unknown expression node type {root.__class__}") diff --git a/src/gems/model/common.py b/src/gems/model/common.py index e295c1aa..b4342735 100644 --- a/src/gems/model/common.py +++ b/src/gems/model/common.py @@ -13,6 +13,7 @@ """ Module for common classes used in models. """ + from enum import Enum diff --git a/src/gems/model/model.py b/src/gems/model/model.py index a49e934d..bd5c8b5c 100644 --- a/src/gems/model/model.py +++ b/src/gems/model/model.py @@ -15,6 +15,7 @@ A model allows to define the behaviour for components, by defining parameters, variables, and equations. """ + import itertools from dataclasses import dataclass, field, replace from typing import Any, Dict, Iterable, Optional @@ -179,9 +180,9 @@ def model( return Model( id=id, constraints={c.name: c for c in constraints} if constraints else {}, - binding_constraints={c.name: c for c in binding_constraints} - if binding_constraints - else {}, + binding_constraints=( + {c.name: c for c in binding_constraints} if binding_constraints else {} + ), parameters={p.name: p for p in parameters} if parameters else {}, variables={v.name: v for v in variables} if variables else {}, objective_contributions=objective_contributions, @@ -189,8 +190,10 @@ def model( objective_operational_contribution=objective_operational_contribution, inter_block_dyn=inter_block_dyn, ports=existing_port_names, - port_fields_definitions={d.port_field: d for d in port_fields_definitions} - if port_fields_definitions - else {}, + port_fields_definitions=( + {d.port_field: d for d in port_fields_definitions} + if port_fields_definitions + else {} + ), extra_outputs=extra_outputs, ) diff --git a/src/gems/model/port.py b/src/gems/model/port.py index cb07319b..9eaaed09 100644 --- a/src/gems/model/port.py +++ b/src/gems/model/port.py @@ -28,8 +28,12 @@ from gems.expression.expression import ( AllTimeSumNode, BinaryOperatorNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -165,6 +169,18 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> None: def port_field(self, node: PortFieldNode) -> None: raise ValueError("Port definition cannot reference another port field.") + def floor(self, node: FloorNode) -> None: + visit(node.operand, self) + + def ceil(self, node: CeilNode) -> None: + visit(node.operand, self) + + def maximum(self, node: MaxNode) -> None: + self._visit_binary_op(node) + + def minimum(self, node: MinNode) -> None: + self._visit_binary_op(node) + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: raise ValueError("Port definition cannot contain port field aggregation.") diff --git a/src/gems/simulation/linear_expression.py b/src/gems/simulation/linear_expression.py index 1f882e64..a0584c07 100644 --- a/src/gems/simulation/linear_expression.py +++ b/src/gems/simulation/linear_expression.py @@ -14,6 +14,7 @@ Specific modelling for "instantiated" linear expressions, with only variables and literal coefficients. """ + import dataclasses from dataclasses import dataclass from typing import Callable, Dict, List, Optional, TypeVar, Union @@ -118,7 +119,6 @@ def apply(self, other: "TimeExpansion") -> "TimeExpansion": @dataclass(frozen=True) class TermKey: - """ Utility class to provide key for a term that contains all term information except coefficient """ diff --git a/src/gems/simulation/linearize.py b/src/gems/simulation/linearize.py index 7e8dbf15..0d3c9e71 100644 --- a/src/gems/simulation/linearize.py +++ b/src/gems/simulation/linearize.py @@ -23,12 +23,16 @@ ) from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, CurrentScenarioIndex, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, NoScenarioIndex, NoTimeIndex, OneScenarioIndex, @@ -200,6 +204,18 @@ def _get_scenario(self, scenario_index: ScenarioIndex) -> Optional[int]: def literal(self, node: LiteralNode) -> LinearExpressionData: return LinearExpressionData([], node.value) + def floor(self, node: FloorNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a floor operator.") + + def ceil(self, node: CeilNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a ceil operator.") + + def maximum(self, node: MaxNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a max operator.") + + def minimum(self, node: MinNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a min operator.") + def comparison(self, node: ComparisonNode) -> LinearExpressionData: raise ValueError("Linear expression cannot contain a comparison operator.") diff --git a/src/gems/simulation/output_values.py b/src/gems/simulation/output_values.py index a0bbc5f7..e8d4eed7 100644 --- a/src/gems/simulation/output_values.py +++ b/src/gems/simulation/output_values.py @@ -7,6 +7,7 @@ """ Utility classes to obtain solver results. """ + import math from dataclasses import dataclass, field from typing import Any, Dict, Mapping, Optional, Set, TypeVar diff --git a/src/gems/study/network.py b/src/gems/study/network.py index 0ac27ccc..53a53687 100644 --- a/src/gems/study/network.py +++ b/src/gems/study/network.py @@ -14,6 +14,7 @@ The network module defines the data model for an instance of network, including nodes, links, and components (model instantations). """ + import itertools from dataclasses import dataclass, field, replace from typing import Any, Dict, Iterable, List, cast diff --git a/src/gems/utils.py b/src/gems/utils.py index 23c1ebaa..7964fd24 100644 --- a/src/gems/utils.py +++ b/src/gems/utils.py @@ -13,6 +13,7 @@ """ Module for technical utilities. """ + import json import pathlib from typing import Any, Callable, Dict, Optional, TypeVar diff --git a/tests/e2e/functional/libs/standard.py b/tests/e2e/functional/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/e2e/functional/libs/standard.py +++ b/tests/e2e/functional/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure diff --git a/tests/e2e/functional/test_libs_python_system_python.py b/tests/e2e/functional/test_libs_python_system_python.py index aea1abff..01a7bebe 100644 --- a/tests/e2e/functional/test_libs_python_system_python.py +++ b/tests/e2e/functional/test_libs_python_system_python.py @@ -33,6 +33,7 @@ - Description: Short-term storage behavior over different horizons and efficiencies. - Names: `test_short_test_horizon_10`, `test_short_test_horizon_5`. """ + import pandas as pd import pytest diff --git a/tests/e2e/integration/libs/standard.py b/tests/e2e/integration/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/e2e/integration/libs/standard.py +++ b/tests/e2e/integration/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure diff --git a/tests/e2e/models/poc-various-models/libs/standard.py b/tests/e2e/models/poc-various-models/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/e2e/models/poc-various-models/libs/standard.py +++ b/tests/e2e/models/poc-various-models/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure diff --git a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py index 0d8d7505..1663ff40 100644 --- a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py +++ b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py @@ -129,9 +129,9 @@ def test_electrolyzer_n_inputs_1() -> None: print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -218,9 +218,9 @@ def test_electrolyzer_n_inputs_2() -> None: print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -313,9 +313,9 @@ def test_electrolyzer_n_inputs_3() -> None: ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) @@ -401,9 +401,9 @@ def test_electrolyzer_n_inputs_4() -> None: ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) diff --git a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py index 2e007571..78bfc5a0 100644 --- a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py +++ b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py @@ -135,9 +135,9 @@ def test_electrolyzer_n_inputs_1( print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -235,9 +235,9 @@ def test_electrolyzer_n_inputs_2( print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -342,9 +342,9 @@ def test_electrolyzer_n_inputs_3( ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) @@ -443,9 +443,9 @@ def test_electrolyzer_n_inputs_4( ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) diff --git a/tests/e2e/models/poc-various-models/test_quota_co2.py b/tests/e2e/models/poc-various-models/test_quota_co2.py index 2bd504ee..ee3bc445 100644 --- a/tests/e2e/models/poc-various-models/test_quota_co2.py +++ b/tests/e2e/models/poc-various-models/test_quota_co2.py @@ -85,6 +85,6 @@ def test_quota_co2() -> None: assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 5500) - assert math.isclose(oil1_p, 50) # type:ignore - assert math.isclose(coal1_p, 50) # type:ignore - assert math.isclose(l12_flow, -50) # type:ignore + assert math.isclose(oil1_p, 50) # type: ignore + assert math.isclose(coal1_p, 50) # type: ignore + assert math.isclose(l12_flow, -50) # type: ignore diff --git a/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py b/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py index e0163cc7..812a75cc 100644 --- a/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py +++ b/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py @@ -91,6 +91,6 @@ def test_quota_co2( assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 5500) - assert math.isclose(oil1_p, 50) # type:ignore - assert math.isclose(coal1_p, 50) # type:ignore - assert math.isclose(l12_flow, -50) # type:ignore + assert math.isclose(oil1_p, 50) # type: ignore + assert math.isclose(coal1_p, 50) # type: ignore + assert math.isclose(l12_flow, -50) # type: ignore diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index d6558d7f..5998116f 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -15,7 +15,7 @@ from gems.expression import ExpressionNode, literal, param, print_expr, var from gems.expression.equality import expressions_equal -from gems.expression.expression import port_field +from gems.expression.expression import maximum, minimum, port_field from gems.expression.parsing.parse_expression import ( AntaresParseException, ModelIdentifiers, @@ -119,6 +119,42 @@ "expec(sum(cost * generation))", (param("cost") * var("generation")).time_sum().expec(), ), + ( + {}, + {"p"}, + "floor(p)", + param("p").floor(), + ), + ( + {}, + {"p"}, + "ceil(p)", + param("p").ceil(), + ), + ( + {}, + {"a", "b"}, + "max(a, b)", + maximum(param("a"), param("b")), + ), + ( + {}, + {"a", "b"}, + "min(a, b)", + minimum(param("a"), param("b")), + ), + ( + {}, + {"p", "q"}, + "ceil(p/q)", + (param("p") / param("q")).ceil(), + ), + ( + {}, + {"p", "q"}, + "max(0, ceil(p/q))", + maximum(literal(0), (param("p") / param("q")).ceil()), + ), ], ) def test_parsing_visitor( diff --git a/tests/unittests/expressions/visitor/test_degree.py b/tests/unittests/expressions/visitor/test_degree.py index 72996d64..0df558c1 100644 --- a/tests/unittests/expressions/visitor/test_degree.py +++ b/tests/unittests/expressions/visitor/test_degree.py @@ -10,9 +10,20 @@ # # This file is part of the Antares project. +import math + import pytest -from gems.expression import ExpressionDegreeVisitor, LiteralNode, param, var, visit +from gems.expression import ( + ExpressionDegreeVisitor, + LiteralNode, + maximum, + minimum, + param, + var, + visit, +) +from gems.expression.expression import CeilNode, FloorNode def test_degree() -> None: @@ -26,6 +37,28 @@ def test_degree() -> None: assert visit(expr, ExpressionDegreeVisitor()) == 2 +def test_floor_ceil_degree() -> None: + x = var("x") + p = param("p") + + assert visit(FloorNode(p), ExpressionDegreeVisitor()) == 0 + assert visit(CeilNode(p), ExpressionDegreeVisitor()) == 0 + assert visit(FloorNode(x), ExpressionDegreeVisitor()) == math.inf + assert visit(CeilNode(x), ExpressionDegreeVisitor()) == math.inf + + +def test_max_min_degree() -> None: + x = var("x") + p = param("p") + q = param("q") + + assert visit(maximum(p, q), ExpressionDegreeVisitor()) == 0 + assert visit(minimum(p, q), ExpressionDegreeVisitor()) == 0 + assert visit(maximum(x, p), ExpressionDegreeVisitor()) == math.inf + assert visit(minimum(p, x), ExpressionDegreeVisitor()) == math.inf + assert visit(maximum(x, x), ExpressionDegreeVisitor()) == math.inf + + @pytest.mark.xfail(reason="Degree simplification not implemented") def test_degree_computation_should_take_into_account_simplifications() -> None: x = var("x") diff --git a/tests/unittests/expressions/visitor/test_equality.py b/tests/unittests/expressions/visitor/test_equality.py index bcdec8ff..b2efd0fb 100644 --- a/tests/unittests/expressions/visitor/test_equality.py +++ b/tests/unittests/expressions/visitor/test_equality.py @@ -14,6 +14,7 @@ from gems.expression import ExpressionNode, copy_expression, literal, param, var from gems.expression.equality import expressions_equal +from gems.expression.expression import maximum, minimum @pytest.mark.parametrize( @@ -30,6 +31,10 @@ var("x").time_sum(), var("x") + 5 <= 2, var("x").expec(), + var("x").floor(), + var("x").ceil(), + maximum(var("x"), param("p")), + minimum(var("x"), param("p")), ], ) def test_equals(expr: ExpressionNode) -> None: @@ -52,6 +57,14 @@ def test_equals(expr: ExpressionNode) -> None: var("x").time_sum(1, 10), ), (var("x").expec(), var("y").expec()), + # floor / ceil + (var("x").floor(), var("y").floor()), + (var("x").ceil(), var("y").ceil()), + (var("x").floor(), var("x").ceil()), # different node type + # max / min + (maximum(var("x"), param("p")), maximum(var("y"), param("p"))), + (minimum(var("x"), param("p")), minimum(var("x"), param("q"))), + (maximum(var("x"), param("p")), minimum(var("x"), param("p"))), # Max vs Min ], ) def test_not_equals(lhs: ExpressionNode, rhs: ExpressionNode) -> None: diff --git a/tests/unittests/expressions/visitor/test_evaluation.py b/tests/unittests/expressions/visitor/test_evaluation.py index 62e25183..5566619c 100644 --- a/tests/unittests/expressions/visitor/test_evaluation.py +++ b/tests/unittests/expressions/visitor/test_evaluation.py @@ -111,3 +111,22 @@ def test_sum_expressions() -> None: assert expressions_equal( sum_expressions([literal(1), var("x"), param("p")]), 1 + (var("x") + param("p")) ) + + +def test_floor_ceil_max_min() -> None: + from gems.expression.expression import maximum, minimum + + context = EvaluationContext(parameters={"p": 2.7, "q": 1.3}) + + assert visit(param("p").floor(), EvaluationVisitor(context)) == 2.0 + assert visit(param("p").ceil(), EvaluationVisitor(context)) == 3.0 + assert visit( + maximum(param("p"), param("q")), EvaluationVisitor(context) + ) == pytest.approx(2.7) + assert visit( + minimum(param("p"), param("q")), EvaluationVisitor(context) + ) == pytest.approx(1.3) + assert visit( + maximum(literal(0), param("q")), EvaluationVisitor(context) + ) == pytest.approx(1.3) + assert visit(maximum(literal(0), -param("p")), EvaluationVisitor(context)) == 0.0 diff --git a/tests/unittests/expressions/visitor/test_printer.py b/tests/unittests/expressions/visitor/test_printer.py index a41250af..f9cc82b1 100644 --- a/tests/unittests/expressions/visitor/test_printer.py +++ b/tests/unittests/expressions/visitor/test_printer.py @@ -19,3 +19,20 @@ def test_comparison() -> None: expr: ExpressionNode = (5 * x + 3) >= p - 2 assert visit(expr, PrinterVisitor()) == "((5.0 * x) + 3.0) >= (p - 2.0)" + + +def test_floor_ceil_max_min_printer() -> None: + from gems.expression.expression import maximum, minimum + + p = param("p") + q = param("q") + + assert visit(p.floor(), PrinterVisitor()) == "floor(p)" + assert visit(p.ceil(), PrinterVisitor()) == "ceil(p)" + assert visit(maximum(p, q), PrinterVisitor()) == "max(p, q)" + assert visit(minimum(p, q), PrinterVisitor()) == "min(p, q)" + assert visit((p / q).ceil(), PrinterVisitor()) == "ceil((p / q))" + assert ( + visit(maximum(param("a"), (p / q).ceil()), PrinterVisitor()) + == "max(a, ceil((p / q)))" + ) diff --git a/tests/unittests/lib_parsing/test_objective_contribution.py b/tests/unittests/lib_parsing/test_objective_contribution.py index 49908930..18b48991 100644 --- a/tests/unittests/lib_parsing/test_objective_contribution.py +++ b/tests/unittests/lib_parsing/test_objective_contribution.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MPL-2.0 """Tests for objective creation logic and coefficient accumulation.""" + from typing import Any, Dict, Optional from unittest.mock import Mock, patch @@ -97,7 +98,7 @@ def name(self) -> str: def _setup_mock_optimization_environment( - linear_expressions: Dict[Any, LinearExpression] + linear_expressions: Dict[Any, LinearExpression], ) -> Any: """Sets up a mock context and problem with real OR-Tools solver objects.""" diff --git a/tests/unittests/system/libs/standard.py b/tests/unittests/system/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/unittests/system/libs/standard.py +++ b/tests/unittests/system/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure