diff --git a/.gitignore b/.gitignore index 1edab6c2..59318f35 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ bin outputs build src/gems.egg-info/ +src/gemspy.egg-info/ diff --git a/grammar/Expr.g4 b/grammar/Expr.g4 index 73e4d0ba..f990402f 100644 --- a/grammar/Expr.g4 +++ b/grammar/Expr.g4 @@ -28,13 +28,15 @@ expr | 'sum' '(' expr ')' # allTimeSum | 'sum_connections' '(' portFieldExpr ')' # portFieldSum | 'sum' '(' from=shift '..' to=shift ',' expr ')' # timeSum - | IDENTIFIER '(' expr ')' # function + | IDENTIFIER '(' argList? ')' # function | IDENTIFIER '[' shift ']' # timeShift | IDENTIFIER '[' expr ']' # timeIndex | '(' expr ')' '[' shift ']' # timeShiftExpr | '(' expr ')' '[' expr ']' # timeIndexExpr ; +argList : expr (',' expr)* ; + atom : NUMBER # number | IDENTIFIER # identifier diff --git a/src/gems/expression/__init__.py b/src/gems/expression/__init__.py index 3d949b6e..1eafc790 100644 --- a/src/gems/expression/__init__.py +++ b/src/gems/expression/__init__.py @@ -20,16 +20,22 @@ ) from .expression import ( AdditionNode, + CeilNode, Comparator, ComparisonNode, DivisionNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, MultiplicationNode, NegationNode, ParameterNode, VariableNode, literal, + maximum, + minimum, param, sum_expressions, var, diff --git a/src/gems/expression/copy.py b/src/gems/expression/copy.py index eea34d74..a5276ccd 100644 --- a/src/gems/expression/copy.py +++ b/src/gems/expression/copy.py @@ -15,11 +15,15 @@ from .expression import ( AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, ParameterNode, PortFieldAggregatorNode, PortFieldNode, @@ -95,6 +99,18 @@ def port_field(self, node: PortFieldNode) -> ExpressionNode: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode: return PortFieldAggregatorNode(visit(node.operand, self), node.aggregator) + def floor(self, node: FloorNode) -> ExpressionNode: + return FloorNode(visit(node.operand, self)) + + def ceil(self, node: CeilNode) -> ExpressionNode: + return CeilNode(visit(node.operand, self)) + + def maximum(self, node: MaxNode) -> ExpressionNode: + return MaxNode([visit(op, self) for op in node.operands]) + + def minimum(self, node: MinNode) -> ExpressionNode: + return MinNode([visit(op, self) for op in node.operands]) + def copy_expression(expression: ExpressionNode) -> ExpressionNode: return visit(expression, CopyVisitor()) diff --git a/src/gems/expression/degree.py b/src/gems/expression/degree.py index d9051818..220650c1 100644 --- a/src/gems/expression/degree.py +++ b/src/gems/expression/degree.py @@ -10,11 +10,17 @@ # # This file is part of the Antares project. +import math + import gems.expression.scenario_operator from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -39,77 +45,91 @@ from .visitor import ExpressionVisitor, T, visit -class ExpressionDegreeVisitor(ExpressionVisitor[int]): +class ExpressionDegreeVisitor(ExpressionVisitor[int | float]): """ Computes degree of expression with respect to variables. """ - def literal(self, node: LiteralNode) -> int: + def literal(self, node: LiteralNode) -> int | float: return 0 - def negation(self, node: NegationNode) -> int: + def negation(self, node: NegationNode) -> int | float: return visit(node.operand, self) # TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div - def addition(self, node: AdditionNode) -> int: + def addition(self, node: AdditionNode) -> int | float: degrees = [visit(o, self) for o in node.operands] return max(degrees) - def multiplication(self, node: MultiplicationNode) -> int: + def multiplication(self, node: MultiplicationNode) -> int | float: return visit(node.left, self) + visit(node.right, self) - def division(self, node: DivisionNode) -> int: + def division(self, node: DivisionNode) -> int | float: right_degree = visit(node.right, self) if right_degree != 0: raise ValueError("Degree computation not implemented for divisions.") return visit(node.left, self) - def comparison(self, node: ComparisonNode) -> int: + def comparison(self, node: ComparisonNode) -> int | float: return max(visit(node.left, self), visit(node.right, self)) - def variable(self, node: VariableNode) -> int: + def variable(self, node: VariableNode) -> int | float: return 1 - def parameter(self, node: ParameterNode) -> int: + def parameter(self, node: ParameterNode) -> int | float: return 0 - def comp_variable(self, node: ComponentVariableNode) -> int: + def comp_variable(self, node: ComponentVariableNode) -> int | float: return 1 - def comp_parameter(self, node: ComponentParameterNode) -> int: + def comp_parameter(self, node: ComponentParameterNode) -> int | float: return 0 - def pb_variable(self, node: ProblemVariableNode) -> int: + def pb_variable(self, node: ProblemVariableNode) -> int | float: return 1 - def pb_parameter(self, node: ProblemParameterNode) -> int: + def pb_parameter(self, node: ProblemParameterNode) -> int | float: return 0 - def time_shift(self, node: TimeShiftNode) -> int: + def time_shift(self, node: TimeShiftNode) -> int | float: return visit(node.operand, self) - def time_eval(self, node: TimeEvalNode) -> int: + def time_eval(self, node: TimeEvalNode) -> int | float: return visit(node.operand, self) - def time_sum(self, node: TimeSumNode) -> int: + def time_sum(self, node: TimeSumNode) -> int | float: return visit(node.operand, self) - def all_time_sum(self, node: AllTimeSumNode) -> int: + def all_time_sum(self, node: AllTimeSumNode) -> int | float: return visit(node.operand, self) - def scenario_operator(self, node: ScenarioOperatorNode) -> int: + def scenario_operator(self, node: ScenarioOperatorNode) -> int | float: scenario_operator_cls = getattr(gems.expression.scenario_operator, node.name) # TODO: Carefully check if this formula is correct return scenario_operator_cls.degree() * visit(node.operand, self) - def port_field(self, node: PortFieldNode) -> int: + def port_field(self, node: PortFieldNode) -> int | float: return 1 - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int: + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int | float: return visit(node.operand, self) + def floor(self, node: FloorNode) -> int | float: + d = visit(node.operand, self) + return 0 if d == 0 else math.inf + + def ceil(self, node: CeilNode) -> int | float: + d = visit(node.operand, self) + return 0 if d == 0 else math.inf + + def maximum(self, node: MaxNode) -> int | float: + return 0 if all(visit(op, self) == 0 for op in node.operands) else math.inf + + def minimum(self, node: MinNode) -> int | float: + return 0 if all(visit(op, self) == 0 for op in node.operands) else math.inf + -def compute_degree(expression: ExpressionNode) -> int: +def compute_degree(expression: ExpressionNode) -> int | float: return visit(expression, ExpressionDegreeVisitor()) diff --git a/src/gems/expression/equality.py b/src/gems/expression/equality.py index d5a543ff..b96ddaa0 100644 --- a/src/gems/expression/equality.py +++ b/src/gems/expression/equality.py @@ -28,8 +28,12 @@ from gems.expression.expression import ( AllTimeSumNode, BinaryOperatorNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -111,6 +115,14 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: right, PortFieldAggregatorNode ): return self.port_field_aggregator(left, right) + if isinstance(left, FloorNode) and isinstance(right, FloorNode): + return self.floor(left, right) + if isinstance(left, CeilNode) and isinstance(right, CeilNode): + return self.ceil(left, right) + if isinstance(left, MaxNode) and isinstance(right, MaxNode): + return self.maximum(left, right) + if isinstance(left, MinNode) and isinstance(right, MinNode): + return self.minimum(left, right) raise NotImplementedError(f"Equality not implemented for {left.__class__}") def literal(self, left: LiteralNode, right: LiteralNode) -> bool: @@ -215,6 +227,22 @@ def port_field_aggregator( left.operand, right.operand ) + def floor(self, left: FloorNode, right: FloorNode) -> bool: + return self.visit(left.operand, right.operand) + + def ceil(self, left: CeilNode, right: CeilNode) -> bool: + return self.visit(left.operand, right.operand) + + def maximum(self, left: MaxNode, right: MaxNode) -> bool: + return len(left.operands) == len(right.operands) and all( + self.visit(l, r) for l, r in zip(left.operands, right.operands) + ) + + def minimum(self, left: MinNode, right: MinNode) -> bool: + return len(left.operands) == len(right.operands) and all( + self.visit(l, r) for l, r in zip(left.operands, right.operands) + ) + def expressions_equal( left: ExpressionNode, right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0 diff --git a/src/gems/expression/evaluate.py b/src/gems/expression/evaluate.py index 398db7da..2c850eff 100644 --- a/src/gems/expression/evaluate.py +++ b/src/gems/expression/evaluate.py @@ -10,14 +10,19 @@ # # This file is part of the Antares project. +import math from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Dict from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -139,6 +144,18 @@ def port_field(self, node: PortFieldNode) -> float: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float: raise NotImplementedError() + def floor(self, node: FloorNode) -> float: + return float(math.floor(visit(node.operand, self))) + + def ceil(self, node: CeilNode) -> float: + return float(math.ceil(visit(node.operand, self))) + + def maximum(self, node: MaxNode) -> float: + return max(visit(op, self) for op in node.operands) + + def minimum(self, node: MinNode) -> float: + return min(visit(op, self) for op in node.operands) + def evaluate(expression: ExpressionNode, value_provider: ValueProvider) -> float: return visit(expression, EvaluationVisitor(value_provider)) diff --git a/src/gems/expression/expression.py b/src/gems/expression/expression.py index 44bfcfc9..68dcee3d 100644 --- a/src/gems/expression/expression.py +++ b/src/gems/expression/expression.py @@ -13,6 +13,7 @@ """ Defines the model for generic expressions. """ + import enum import inspect from dataclasses import dataclass @@ -117,6 +118,12 @@ def shift(self, shift: AnyExpression) -> "ExpressionNode": def eval(self, time: AnyExpression) -> "ExpressionNode": return TimeEvalNode(self, _wrap_in_node(time)) + def floor(self) -> "ExpressionNode": + return FloorNode(self) + + def ceil(self) -> "ExpressionNode": + return CeilNode(self) + def expec(self) -> "ExpressionNode": return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) @@ -368,6 +375,16 @@ class NegationNode(UnaryOperatorNode): pass +@dataclass(frozen=True, eq=False) +class FloorNode(UnaryOperatorNode): + pass + + +@dataclass(frozen=True, eq=False) +class CeilNode(UnaryOperatorNode): + pass + + @dataclass(frozen=True, eq=False) class BinaryOperatorNode(ExpressionNode): left: ExpressionNode @@ -400,6 +417,24 @@ class DivisionNode(BinaryOperatorNode): pass +@dataclass(frozen=True, eq=False) +class MaxNode(ExpressionNode): + operands: List[ExpressionNode] + + +@dataclass(frozen=True, eq=False) +class MinNode(ExpressionNode): + operands: List[ExpressionNode] + + +def maximum(*operands: "ExpressionNode") -> "MaxNode": + return MaxNode(list(operands)) + + +def minimum(*operands: "ExpressionNode") -> "MinNode": + return MinNode(list(operands)) + + @dataclass(frozen=True, eq=False) class TimeShiftNode(UnaryOperatorNode): time_shift: ExpressionNode diff --git a/src/gems/expression/indexing.py b/src/gems/expression/indexing.py index 022aa936..c7858204 100644 --- a/src/gems/expression/indexing.py +++ b/src/gems/expression/indexing.py @@ -19,12 +19,16 @@ from .expression import ( AdditionNode, AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, DivisionNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, MultiplicationNode, NegationNode, ParameterNode, @@ -159,6 +163,18 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStruct "Port fields aggregators must be resolved before computing indexing structure." ) + def floor(self, node: FloorNode) -> IndexingStructure: + return visit(node.operand, self) + + def ceil(self, node: CeilNode) -> IndexingStructure: + return visit(node.operand, self) + + def maximum(self, node: MaxNode) -> IndexingStructure: + return self._combine(node.operands) + + def minimum(self, node: MinNode) -> IndexingStructure: + return self._combine(node.operands) + def compute_indexation( expression: ExpressionNode, provider: IndexingStructureProvider diff --git a/src/gems/expression/parsing/antlr/Expr.interp b/src/gems/expression/parsing/antlr/Expr.interp index 94f1370f..d2035f0c 100644 --- a/src/gems/expression/parsing/antlr/Expr.interp +++ b/src/gems/expression/parsing/antlr/Expr.interp @@ -44,6 +44,7 @@ rule names: portFieldExpr fullexpr expr +argList atom shift shift_expr @@ -51,4 +52,4 @@ right_expr atn: -[4, 1, 18, 140, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 2, 79, 8, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 90, 8, 2, 10, 2, 12, 2, 93, 9, 2, 1, 3, 1, 3, 3, 3, 97, 8, 3, 1, 4, 1, 4, 3, 4, 101, 8, 4, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 3, 5, 111, 8, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 5, 5, 119, 8, 5, 10, 5, 12, 5, 122, 9, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 130, 8, 6, 1, 6, 1, 6, 1, 6, 5, 6, 135, 8, 6, 10, 6, 12, 6, 138, 9, 6, 1, 6, 0, 3, 4, 10, 12, 7, 0, 2, 4, 6, 8, 10, 12, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 153, 0, 14, 1, 0, 0, 0, 2, 18, 1, 0, 0, 0, 4, 78, 1, 0, 0, 0, 6, 96, 1, 0, 0, 0, 8, 98, 1, 0, 0, 0, 10, 110, 1, 0, 0, 0, 12, 129, 1, 0, 0, 0, 14, 15, 5, 16, 0, 0, 15, 16, 5, 1, 0, 0, 16, 17, 5, 16, 0, 0, 17, 1, 1, 0, 0, 0, 18, 19, 3, 4, 2, 0, 19, 20, 5, 0, 0, 1, 20, 3, 1, 0, 0, 0, 21, 22, 6, 2, -1, 0, 22, 79, 3, 6, 3, 0, 23, 79, 3, 0, 0, 0, 24, 25, 5, 2, 0, 0, 25, 79, 3, 4, 2, 13, 26, 27, 5, 3, 0, 0, 27, 28, 3, 4, 2, 0, 28, 29, 5, 4, 0, 0, 29, 79, 1, 0, 0, 0, 30, 31, 5, 8, 0, 0, 31, 32, 5, 3, 0, 0, 32, 33, 3, 4, 2, 0, 33, 34, 5, 4, 0, 0, 34, 79, 1, 0, 0, 0, 35, 36, 5, 9, 0, 0, 36, 37, 5, 3, 0, 0, 37, 38, 3, 0, 0, 0, 38, 39, 5, 4, 0, 0, 39, 79, 1, 0, 0, 0, 40, 41, 5, 8, 0, 0, 41, 42, 5, 3, 0, 0, 42, 43, 3, 8, 4, 0, 43, 44, 5, 10, 0, 0, 44, 45, 3, 8, 4, 0, 45, 46, 5, 11, 0, 0, 46, 47, 3, 4, 2, 0, 47, 48, 5, 4, 0, 0, 48, 79, 1, 0, 0, 0, 49, 50, 5, 16, 0, 0, 50, 51, 5, 3, 0, 0, 51, 52, 3, 4, 2, 0, 52, 53, 5, 4, 0, 0, 53, 79, 1, 0, 0, 0, 54, 55, 5, 16, 0, 0, 55, 56, 5, 12, 0, 0, 56, 57, 3, 8, 4, 0, 57, 58, 5, 13, 0, 0, 58, 79, 1, 0, 0, 0, 59, 60, 5, 16, 0, 0, 60, 61, 5, 12, 0, 0, 61, 62, 3, 4, 2, 0, 62, 63, 5, 13, 0, 0, 63, 79, 1, 0, 0, 0, 64, 65, 5, 3, 0, 0, 65, 66, 3, 4, 2, 0, 66, 67, 5, 4, 0, 0, 67, 68, 5, 12, 0, 0, 68, 69, 3, 8, 4, 0, 69, 70, 5, 13, 0, 0, 70, 79, 1, 0, 0, 0, 71, 72, 5, 3, 0, 0, 72, 73, 3, 4, 2, 0, 73, 74, 5, 4, 0, 0, 74, 75, 5, 12, 0, 0, 75, 76, 3, 4, 2, 0, 76, 77, 5, 13, 0, 0, 77, 79, 1, 0, 0, 0, 78, 21, 1, 0, 0, 0, 78, 23, 1, 0, 0, 0, 78, 24, 1, 0, 0, 0, 78, 26, 1, 0, 0, 0, 78, 30, 1, 0, 0, 0, 78, 35, 1, 0, 0, 0, 78, 40, 1, 0, 0, 0, 78, 49, 1, 0, 0, 0, 78, 54, 1, 0, 0, 0, 78, 59, 1, 0, 0, 0, 78, 64, 1, 0, 0, 0, 78, 71, 1, 0, 0, 0, 79, 91, 1, 0, 0, 0, 80, 81, 10, 11, 0, 0, 81, 82, 7, 0, 0, 0, 82, 90, 3, 4, 2, 12, 83, 84, 10, 10, 0, 0, 84, 85, 7, 1, 0, 0, 85, 90, 3, 4, 2, 11, 86, 87, 10, 9, 0, 0, 87, 88, 5, 17, 0, 0, 88, 90, 3, 4, 2, 10, 89, 80, 1, 0, 0, 0, 89, 83, 1, 0, 0, 0, 89, 86, 1, 0, 0, 0, 90, 93, 1, 0, 0, 0, 91, 89, 1, 0, 0, 0, 91, 92, 1, 0, 0, 0, 92, 5, 1, 0, 0, 0, 93, 91, 1, 0, 0, 0, 94, 97, 5, 14, 0, 0, 95, 97, 5, 16, 0, 0, 96, 94, 1, 0, 0, 0, 96, 95, 1, 0, 0, 0, 97, 7, 1, 0, 0, 0, 98, 100, 5, 15, 0, 0, 99, 101, 3, 10, 5, 0, 100, 99, 1, 0, 0, 0, 100, 101, 1, 0, 0, 0, 101, 9, 1, 0, 0, 0, 102, 103, 6, 5, -1, 0, 103, 104, 7, 1, 0, 0, 104, 111, 3, 6, 3, 0, 105, 106, 7, 1, 0, 0, 106, 107, 5, 3, 0, 0, 107, 108, 3, 4, 2, 0, 108, 109, 5, 4, 0, 0, 109, 111, 1, 0, 0, 0, 110, 102, 1, 0, 0, 0, 110, 105, 1, 0, 0, 0, 111, 120, 1, 0, 0, 0, 112, 113, 10, 4, 0, 0, 113, 114, 7, 0, 0, 0, 114, 119, 3, 12, 6, 0, 115, 116, 10, 3, 0, 0, 116, 117, 7, 1, 0, 0, 117, 119, 3, 12, 6, 0, 118, 112, 1, 0, 0, 0, 118, 115, 1, 0, 0, 0, 119, 122, 1, 0, 0, 0, 120, 118, 1, 0, 0, 0, 120, 121, 1, 0, 0, 0, 121, 11, 1, 0, 0, 0, 122, 120, 1, 0, 0, 0, 123, 124, 6, 6, -1, 0, 124, 125, 5, 3, 0, 0, 125, 126, 3, 4, 2, 0, 126, 127, 5, 4, 0, 0, 127, 130, 1, 0, 0, 0, 128, 130, 3, 6, 3, 0, 129, 123, 1, 0, 0, 0, 129, 128, 1, 0, 0, 0, 130, 136, 1, 0, 0, 0, 131, 132, 10, 3, 0, 0, 132, 133, 7, 0, 0, 0, 133, 135, 3, 12, 6, 4, 134, 131, 1, 0, 0, 0, 135, 138, 1, 0, 0, 0, 136, 134, 1, 0, 0, 0, 136, 137, 1, 0, 0, 0, 137, 13, 1, 0, 0, 0, 138, 136, 1, 0, 0, 0, 10, 78, 89, 91, 96, 100, 110, 118, 120, 129, 136] \ No newline at end of file +[4, 1, 18, 151, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 2, 55, 8, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 2, 82, 8, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 93, 8, 2, 10, 2, 12, 2, 96, 9, 2, 1, 3, 1, 3, 1, 3, 5, 3, 101, 8, 3, 10, 3, 12, 3, 104, 9, 3, 1, 4, 1, 4, 3, 4, 108, 8, 4, 1, 5, 1, 5, 3, 5, 112, 8, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 122, 8, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 5, 6, 130, 8, 6, 10, 6, 12, 6, 133, 9, 6, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 3, 7, 141, 8, 7, 1, 7, 1, 7, 1, 7, 5, 7, 146, 8, 7, 10, 7, 12, 7, 149, 9, 7, 1, 7, 0, 3, 4, 12, 14, 8, 0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 165, 0, 16, 1, 0, 0, 0, 2, 20, 1, 0, 0, 0, 4, 81, 1, 0, 0, 0, 6, 97, 1, 0, 0, 0, 8, 107, 1, 0, 0, 0, 10, 109, 1, 0, 0, 0, 12, 121, 1, 0, 0, 0, 14, 140, 1, 0, 0, 0, 16, 17, 5, 16, 0, 0, 17, 18, 5, 1, 0, 0, 18, 19, 5, 16, 0, 0, 19, 1, 1, 0, 0, 0, 20, 21, 3, 4, 2, 0, 21, 22, 5, 0, 0, 1, 22, 3, 1, 0, 0, 0, 23, 24, 6, 2, -1, 0, 24, 82, 3, 8, 4, 0, 25, 82, 3, 0, 0, 0, 26, 27, 5, 2, 0, 0, 27, 82, 3, 4, 2, 13, 28, 29, 5, 3, 0, 0, 29, 30, 3, 4, 2, 0, 30, 31, 5, 4, 0, 0, 31, 82, 1, 0, 0, 0, 32, 33, 5, 8, 0, 0, 33, 34, 5, 3, 0, 0, 34, 35, 3, 4, 2, 0, 35, 36, 5, 4, 0, 0, 36, 82, 1, 0, 0, 0, 37, 38, 5, 9, 0, 0, 38, 39, 5, 3, 0, 0, 39, 40, 3, 0, 0, 0, 40, 41, 5, 4, 0, 0, 41, 82, 1, 0, 0, 0, 42, 43, 5, 8, 0, 0, 43, 44, 5, 3, 0, 0, 44, 45, 3, 10, 5, 0, 45, 46, 5, 10, 0, 0, 46, 47, 3, 10, 5, 0, 47, 48, 5, 11, 0, 0, 48, 49, 3, 4, 2, 0, 49, 50, 5, 4, 0, 0, 50, 82, 1, 0, 0, 0, 51, 52, 5, 16, 0, 0, 52, 54, 5, 3, 0, 0, 53, 55, 3, 6, 3, 0, 54, 53, 1, 0, 0, 0, 54, 55, 1, 0, 0, 0, 55, 56, 1, 0, 0, 0, 56, 82, 5, 4, 0, 0, 57, 58, 5, 16, 0, 0, 58, 59, 5, 12, 0, 0, 59, 60, 3, 10, 5, 0, 60, 61, 5, 13, 0, 0, 61, 82, 1, 0, 0, 0, 62, 63, 5, 16, 0, 0, 63, 64, 5, 12, 0, 0, 64, 65, 3, 4, 2, 0, 65, 66, 5, 13, 0, 0, 66, 82, 1, 0, 0, 0, 67, 68, 5, 3, 0, 0, 68, 69, 3, 4, 2, 0, 69, 70, 5, 4, 0, 0, 70, 71, 5, 12, 0, 0, 71, 72, 3, 10, 5, 0, 72, 73, 5, 13, 0, 0, 73, 82, 1, 0, 0, 0, 74, 75, 5, 3, 0, 0, 75, 76, 3, 4, 2, 0, 76, 77, 5, 4, 0, 0, 77, 78, 5, 12, 0, 0, 78, 79, 3, 4, 2, 0, 79, 80, 5, 13, 0, 0, 80, 82, 1, 0, 0, 0, 81, 23, 1, 0, 0, 0, 81, 25, 1, 0, 0, 0, 81, 26, 1, 0, 0, 0, 81, 28, 1, 0, 0, 0, 81, 32, 1, 0, 0, 0, 81, 37, 1, 0, 0, 0, 81, 42, 1, 0, 0, 0, 81, 51, 1, 0, 0, 0, 81, 57, 1, 0, 0, 0, 81, 62, 1, 0, 0, 0, 81, 67, 1, 0, 0, 0, 81, 74, 1, 0, 0, 0, 82, 94, 1, 0, 0, 0, 83, 84, 10, 11, 0, 0, 84, 85, 7, 0, 0, 0, 85, 93, 3, 4, 2, 12, 86, 87, 10, 10, 0, 0, 87, 88, 7, 1, 0, 0, 88, 93, 3, 4, 2, 11, 89, 90, 10, 9, 0, 0, 90, 91, 5, 17, 0, 0, 91, 93, 3, 4, 2, 10, 92, 83, 1, 0, 0, 0, 92, 86, 1, 0, 0, 0, 92, 89, 1, 0, 0, 0, 93, 96, 1, 0, 0, 0, 94, 92, 1, 0, 0, 0, 94, 95, 1, 0, 0, 0, 95, 5, 1, 0, 0, 0, 96, 94, 1, 0, 0, 0, 97, 102, 3, 4, 2, 0, 98, 99, 5, 11, 0, 0, 99, 101, 3, 4, 2, 0, 100, 98, 1, 0, 0, 0, 101, 104, 1, 0, 0, 0, 102, 100, 1, 0, 0, 0, 102, 103, 1, 0, 0, 0, 103, 7, 1, 0, 0, 0, 104, 102, 1, 0, 0, 0, 105, 108, 5, 14, 0, 0, 106, 108, 5, 16, 0, 0, 107, 105, 1, 0, 0, 0, 107, 106, 1, 0, 0, 0, 108, 9, 1, 0, 0, 0, 109, 111, 5, 15, 0, 0, 110, 112, 3, 12, 6, 0, 111, 110, 1, 0, 0, 0, 111, 112, 1, 0, 0, 0, 112, 11, 1, 0, 0, 0, 113, 114, 6, 6, -1, 0, 114, 115, 7, 1, 0, 0, 115, 122, 3, 8, 4, 0, 116, 117, 7, 1, 0, 0, 117, 118, 5, 3, 0, 0, 118, 119, 3, 4, 2, 0, 119, 120, 5, 4, 0, 0, 120, 122, 1, 0, 0, 0, 121, 113, 1, 0, 0, 0, 121, 116, 1, 0, 0, 0, 122, 131, 1, 0, 0, 0, 123, 124, 10, 4, 0, 0, 124, 125, 7, 0, 0, 0, 125, 130, 3, 14, 7, 0, 126, 127, 10, 3, 0, 0, 127, 128, 7, 1, 0, 0, 128, 130, 3, 14, 7, 0, 129, 123, 1, 0, 0, 0, 129, 126, 1, 0, 0, 0, 130, 133, 1, 0, 0, 0, 131, 129, 1, 0, 0, 0, 131, 132, 1, 0, 0, 0, 132, 13, 1, 0, 0, 0, 133, 131, 1, 0, 0, 0, 134, 135, 6, 7, -1, 0, 135, 136, 5, 3, 0, 0, 136, 137, 3, 4, 2, 0, 137, 138, 5, 4, 0, 0, 138, 141, 1, 0, 0, 0, 139, 141, 3, 8, 4, 0, 140, 134, 1, 0, 0, 0, 140, 139, 1, 0, 0, 0, 141, 147, 1, 0, 0, 0, 142, 143, 10, 3, 0, 0, 143, 144, 7, 0, 0, 0, 144, 146, 3, 14, 7, 4, 145, 142, 1, 0, 0, 0, 146, 149, 1, 0, 0, 0, 147, 145, 1, 0, 0, 0, 147, 148, 1, 0, 0, 0, 148, 15, 1, 0, 0, 0, 149, 147, 1, 0, 0, 0, 12, 54, 81, 92, 94, 102, 107, 111, 121, 129, 131, 140, 147] \ No newline at end of file diff --git a/src/gems/expression/parsing/antlr/ExprParser.py b/src/gems/expression/parsing/antlr/ExprParser.py index 859d23dc..bdc68001 100644 --- a/src/gems/expression/parsing/antlr/ExprParser.py +++ b/src/gems/expression/parsing/antlr/ExprParser.py @@ -16,7 +16,7 @@ def serializedATN(): 4, 1, 18, - 140, + 151, 2, 0, 7, @@ -45,6 +45,10 @@ def serializedATN(): 6, 7, 6, + 2, + 7, + 7, + 7, 1, 0, 1, @@ -121,7 +125,10 @@ def serializedATN(): 2, 1, 2, - 1, + 3, + 2, + 55, + 8, 2, 1, 2, @@ -175,7 +182,7 @@ def serializedATN(): 2, 3, 2, - 79, + 82, 8, 2, 1, @@ -198,79 +205,56 @@ def serializedATN(): 2, 5, 2, - 90, + 93, 8, 2, 10, 2, 12, 2, - 93, + 96, 9, 2, 1, 3, 1, 3, + 1, 3, + 5, 3, - 97, + 101, 8, 3, + 10, + 3, + 12, + 3, + 104, + 9, + 3, 1, 4, 1, 4, 3, 4, - 101, + 108, 8, 4, 1, 5, 1, 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, 3, 5, - 111, + 112, 8, 5, 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, + 6, 1, - 5, - 5, - 5, - 119, - 8, - 5, - 10, - 5, - 12, - 5, - 122, - 9, - 5, + 6, 1, 6, 1, @@ -285,7 +269,7 @@ def serializedATN(): 6, 3, 6, - 130, + 122, 8, 6, 1, @@ -294,26 +278,67 @@ def serializedATN(): 6, 1, 6, + 1, + 6, + 1, + 6, + 1, + 6, 5, 6, - 135, + 130, 8, 6, 10, 6, 12, 6, - 138, + 133, 9, 6, 1, - 6, - 0, + 7, + 1, + 7, + 1, + 7, + 1, + 7, + 1, + 7, + 1, + 7, 3, - 4, + 7, + 141, + 8, + 7, + 1, + 7, + 1, + 7, + 1, + 7, + 5, + 7, + 146, + 8, + 7, 10, + 7, 12, 7, + 149, + 9, + 7, + 1, + 7, + 0, + 3, + 4, + 12, + 14, + 8, 0, 2, 4, @@ -321,6 +346,7 @@ def serializedATN(): 8, 10, 12, + 14, 0, 2, 1, @@ -333,413 +359,413 @@ def serializedATN(): 2, 7, 7, - 153, + 165, 0, - 14, + 16, 1, 0, 0, 0, 2, - 18, + 20, 1, 0, 0, 0, 4, - 78, + 81, 1, 0, 0, 0, 6, - 96, + 97, 1, 0, 0, 0, 8, - 98, + 107, 1, 0, 0, 0, 10, - 110, + 109, 1, 0, 0, 0, 12, - 129, + 121, 1, 0, 0, 0, 14, - 15, + 140, + 1, + 0, + 0, + 0, + 16, + 17, 5, 16, 0, 0, - 15, - 16, + 17, + 18, 5, 1, 0, 0, - 16, - 17, + 18, + 19, 5, 16, 0, 0, - 17, + 19, 1, 1, 0, 0, 0, - 18, - 19, + 20, + 21, 3, 4, 2, 0, - 19, - 20, + 21, + 22, 5, 0, 0, 1, - 20, + 22, 3, 1, 0, 0, 0, - 21, - 22, + 23, + 24, 6, 2, -1, 0, - 22, - 79, - 3, - 6, + 24, + 82, 3, + 8, + 4, 0, - 23, - 79, + 25, + 82, 3, 0, 0, 0, - 24, - 25, + 26, + 27, 5, 2, 0, 0, - 25, - 79, + 27, + 82, 3, 4, 2, 13, - 26, - 27, + 28, + 29, 5, 3, 0, 0, - 27, - 28, + 29, + 30, 3, 4, 2, 0, - 28, - 29, + 30, + 31, 5, 4, 0, 0, - 29, - 79, + 31, + 82, 1, 0, 0, 0, - 30, - 31, + 32, + 33, 5, 8, 0, 0, - 31, - 32, + 33, + 34, 5, 3, 0, 0, - 32, - 33, + 34, + 35, 3, 4, 2, 0, - 33, - 34, - 5, - 4, - 0, - 0, - 34, - 79, - 1, - 0, - 0, - 0, 35, 36, 5, - 9, + 4, 0, 0, 36, - 37, - 5, - 3, + 82, + 1, + 0, 0, 0, 37, 38, - 3, - 0, + 5, + 9, 0, 0, 38, 39, 5, - 4, + 3, 0, 0, 39, - 79, - 1, + 40, + 3, 0, 0, 0, 40, 41, 5, - 8, + 4, 0, 0, 41, - 42, - 5, - 3, + 82, + 1, + 0, 0, 0, 42, 43, - 3, + 5, 8, - 4, + 0, 0, 43, 44, 5, - 10, + 3, 0, 0, 44, 45, 3, - 8, - 4, + 10, + 5, 0, 45, 46, 5, - 11, + 10, 0, 0, 46, 47, 3, - 4, - 2, + 10, + 5, 0, 47, 48, 5, - 4, + 11, 0, 0, 48, - 79, - 1, - 0, - 0, + 49, + 3, + 4, + 2, 0, 49, 50, 5, - 16, + 4, 0, 0, 50, - 51, - 5, - 3, + 82, + 1, + 0, 0, 0, 51, 52, - 3, - 4, - 2, + 5, + 16, + 0, 0, 52, - 53, + 54, 5, - 4, + 3, 0, 0, 53, - 79, + 55, + 3, + 6, + 3, + 0, + 54, + 53, 1, 0, 0, 0, 54, 55, - 5, - 16, + 1, + 0, 0, 0, 55, 56, - 5, - 12, + 1, + 0, 0, 0, 56, - 57, - 3, - 8, + 82, + 5, 4, 0, + 0, 57, 58, 5, - 13, + 16, 0, 0, 58, - 79, - 1, - 0, + 59, + 5, + 12, 0, 0, 59, 60, + 3, + 10, 5, - 16, - 0, 0, 60, 61, 5, - 12, + 13, 0, 0, 61, - 62, - 3, - 4, - 2, + 82, + 1, + 0, + 0, 0, 62, 63, 5, - 13, + 16, 0, 0, 63, - 79, - 1, - 0, - 0, - 0, 64, - 65, 5, - 3, + 12, 0, 0, + 64, 65, - 66, 3, 4, 2, 0, + 65, 66, - 67, 5, - 4, + 13, + 0, + 0, + 66, + 82, + 1, + 0, 0, 0, 67, 68, 5, - 12, + 3, 0, 0, 68, 69, 3, - 8, 4, + 2, 0, 69, 70, 5, - 13, + 4, 0, 0, 70, - 79, - 1, - 0, + 71, + 5, + 12, 0, 0, 71, 72, - 5, 3, - 0, + 10, + 5, 0, 72, 73, - 3, - 4, - 2, + 5, + 13, + 0, 0, 73, - 74, - 5, - 4, + 82, + 1, + 0, 0, 0, 74, 75, 5, - 12, + 3, 0, 0, 75, @@ -751,518 +777,592 @@ def serializedATN(): 76, 77, 5, - 13, + 4, 0, 0, 77, + 78, + 5, + 12, + 0, + 0, + 78, 79, - 1, + 3, + 4, + 2, 0, + 79, + 80, + 5, + 13, 0, 0, - 78, - 21, + 80, + 82, 1, 0, 0, 0, - 78, + 81, 23, 1, 0, 0, 0, - 78, - 24, + 81, + 25, 1, 0, 0, 0, - 78, + 81, 26, 1, 0, 0, 0, - 78, - 30, + 81, + 28, 1, 0, 0, 0, - 78, - 35, + 81, + 32, 1, 0, 0, 0, - 78, - 40, + 81, + 37, 1, 0, 0, 0, - 78, - 49, + 81, + 42, 1, 0, 0, 0, - 78, - 54, + 81, + 51, 1, 0, 0, 0, - 78, - 59, + 81, + 57, 1, 0, 0, 0, - 78, - 64, + 81, + 62, 1, 0, 0, 0, - 78, - 71, + 81, + 67, 1, 0, 0, 0, - 79, - 91, + 81, + 74, 1, 0, 0, 0, - 80, - 81, + 82, + 94, + 1, + 0, + 0, + 0, + 83, + 84, 10, 11, 0, 0, - 81, - 82, + 84, + 85, 7, 0, 0, 0, - 82, - 90, + 85, + 93, 3, 4, 2, 12, - 83, - 84, + 86, + 87, 10, 10, 0, 0, - 84, - 85, + 87, + 88, 7, 1, 0, 0, - 85, - 90, + 88, + 93, 3, 4, 2, 11, - 86, - 87, + 89, + 90, 10, 9, 0, 0, - 87, - 88, + 90, + 91, 5, 17, 0, 0, - 88, - 90, + 91, + 93, 3, 4, 2, 10, - 89, - 80, + 92, + 83, 1, 0, 0, 0, - 89, - 83, + 92, + 86, 1, 0, 0, 0, + 92, 89, - 86, 1, 0, 0, 0, - 90, 93, + 96, 1, 0, 0, 0, - 91, - 89, + 94, + 92, 1, 0, 0, 0, - 91, - 92, + 94, + 95, 1, 0, 0, 0, - 92, + 95, 5, 1, 0, 0, 0, - 93, - 91, + 96, + 94, 1, 0, 0, 0, - 94, 97, + 102, + 3, + 4, + 2, + 0, + 98, + 99, + 5, + 11, + 0, + 0, + 99, + 101, + 3, + 4, + 2, + 0, + 100, + 98, + 1, + 0, + 0, + 0, + 101, + 104, + 1, + 0, + 0, + 0, + 102, + 100, + 1, + 0, + 0, + 0, + 102, + 103, + 1, + 0, + 0, + 0, + 103, + 7, + 1, + 0, + 0, + 0, + 104, + 102, + 1, + 0, + 0, + 0, + 105, + 108, 5, 14, 0, 0, - 95, - 97, + 106, + 108, 5, 16, 0, 0, - 96, - 94, + 107, + 105, 1, 0, 0, 0, - 96, - 95, + 107, + 106, 1, 0, 0, 0, - 97, - 7, + 108, + 9, 1, 0, 0, 0, - 98, - 100, + 109, + 111, 5, 15, 0, 0, - 99, - 101, + 110, + 112, 3, - 10, - 5, + 12, + 6, 0, - 100, - 99, + 111, + 110, 1, 0, 0, 0, - 100, - 101, + 111, + 112, 1, 0, 0, 0, - 101, - 9, + 112, + 11, 1, 0, 0, 0, - 102, - 103, + 113, + 114, + 6, 6, - 5, -1, 0, - 103, - 104, + 114, + 115, 7, 1, 0, 0, - 104, - 111, - 3, - 6, + 115, + 122, 3, + 8, + 4, 0, - 105, - 106, + 116, + 117, 7, 1, 0, 0, - 106, - 107, + 117, + 118, 5, 3, 0, 0, - 107, - 108, + 118, + 119, 3, 4, 2, 0, - 108, - 109, + 119, + 120, 5, 4, 0, 0, - 109, - 111, + 120, + 122, 1, 0, 0, 0, - 110, - 102, + 121, + 113, 1, 0, 0, 0, - 110, - 105, + 121, + 116, 1, 0, 0, 0, - 111, - 120, + 122, + 131, 1, 0, 0, 0, - 112, - 113, + 123, + 124, 10, 4, 0, 0, - 113, - 114, + 124, + 125, 7, 0, 0, 0, - 114, - 119, + 125, + 130, 3, - 12, - 6, + 14, + 7, 0, - 115, - 116, + 126, + 127, 10, 3, 0, 0, - 116, - 117, + 127, + 128, 7, 1, 0, 0, - 117, - 119, + 128, + 130, 3, - 12, - 6, - 0, - 118, - 112, + 14, + 7, + 0, + 129, + 123, 1, 0, 0, 0, - 118, - 115, + 129, + 126, 1, 0, 0, 0, - 119, - 122, + 130, + 133, 1, 0, 0, 0, - 120, - 118, + 131, + 129, 1, 0, 0, 0, - 120, - 121, + 131, + 132, 1, 0, 0, 0, - 121, - 11, + 132, + 13, 1, 0, 0, 0, - 122, - 120, + 133, + 131, 1, 0, 0, 0, - 123, - 124, - 6, + 134, + 135, 6, + 7, -1, 0, - 124, - 125, + 135, + 136, 5, 3, 0, 0, - 125, - 126, + 136, + 137, 3, 4, 2, 0, - 126, - 127, + 137, + 138, 5, 4, 0, 0, - 127, - 130, + 138, + 141, 1, 0, 0, 0, - 128, - 130, - 3, - 6, + 139, + 141, 3, + 8, + 4, 0, - 129, - 123, + 140, + 134, 1, 0, 0, 0, - 129, - 128, + 140, + 139, 1, 0, 0, 0, - 130, - 136, + 141, + 147, 1, 0, 0, 0, - 131, - 132, + 142, + 143, 10, 3, 0, 0, - 132, - 133, + 143, + 144, 7, 0, 0, 0, - 133, - 135, + 144, + 146, 3, - 12, - 6, + 14, + 7, 4, - 134, - 131, + 145, + 142, 1, 0, 0, 0, - 135, - 138, + 146, + 149, 1, 0, 0, 0, - 136, - 134, + 147, + 145, 1, 0, 0, 0, - 136, - 137, + 147, + 148, 1, 0, 0, 0, - 137, - 13, + 148, + 15, 1, 0, 0, 0, - 138, - 136, + 149, + 147, 1, 0, 0, 0, - 10, - 78, - 89, - 91, - 96, - 100, - 110, - 118, - 120, + 12, + 54, + 81, + 92, + 94, + 102, + 107, + 111, + 121, 129, - 136, + 131, + 140, + 147, ] @@ -1319,15 +1419,17 @@ class ExprParser(Parser): RULE_portFieldExpr = 0 RULE_fullexpr = 1 RULE_expr = 2 - RULE_atom = 3 - RULE_shift = 4 - RULE_shift_expr = 5 - RULE_right_expr = 6 + RULE_argList = 3 + RULE_atom = 4 + RULE_shift = 5 + RULE_shift_expr = 6 + RULE_right_expr = 7 ruleNames = [ "portFieldExpr", "fullexpr", "expr", + "argList", "atom", "shift", "shift_expr", @@ -1391,11 +1493,11 @@ def portFieldExpr(self): self.enterRule(localctx, 0, self.RULE_portFieldExpr) try: self.enterOuterAlt(localctx, 1) - self.state = 14 + self.state = 16 self.match(ExprParser.IDENTIFIER) - self.state = 15 + self.state = 17 self.match(ExprParser.T__0) - self.state = 16 + self.state = 18 self.match(ExprParser.IDENTIFIER) except RecognitionException as re: localctx.exception = re @@ -1434,9 +1536,9 @@ def fullexpr(self): self.enterRule(localctx, 2, self.RULE_fullexpr) try: self.enterOuterAlt(localctx, 1) - self.state = 18 + self.state = 20 self.expr(0) - self.state = 19 + self.state = 21 self.match(ExprParser.EOF) except RecognitionException as re: localctx.exception = re @@ -1729,8 +1831,8 @@ def __init__( def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) - def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + def argList(self): + return self.getTypedRuleContext(ExprParser.ArgListContext, 0) def accept(self, visitor: ParseTreeVisitor): if hasattr(visitor, "visitFunction"): @@ -1748,15 +1850,15 @@ def expr(self, _p: int = 0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 78 + self.state = 81 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 0, self._ctx) + la_ = self._interp.adaptivePredict(self._input, 1, self._ctx) if la_ == 1: localctx = ExprParser.UnsignedAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 22 + self.state = 24 self.atom() pass @@ -1764,7 +1866,7 @@ def expr(self, _p: int = 0): localctx = ExprParser.PortFieldContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 23 + self.state = 25 self.portFieldExpr() pass @@ -1772,9 +1874,9 @@ def expr(self, _p: int = 0): localctx = ExprParser.NegationContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 24 + self.state = 26 self.match(ExprParser.T__1) - self.state = 25 + self.state = 27 self.expr(13) pass @@ -1782,11 +1884,11 @@ def expr(self, _p: int = 0): localctx = ExprParser.ExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 26 + self.state = 28 self.match(ExprParser.T__2) - self.state = 27 + self.state = 29 self.expr(0) - self.state = 28 + self.state = 30 self.match(ExprParser.T__3) pass @@ -1794,13 +1896,13 @@ def expr(self, _p: int = 0): localctx = ExprParser.AllTimeSumContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 30 + self.state = 32 self.match(ExprParser.T__7) - self.state = 31 + self.state = 33 self.match(ExprParser.T__2) - self.state = 32 + self.state = 34 self.expr(0) - self.state = 33 + self.state = 35 self.match(ExprParser.T__3) pass @@ -1808,13 +1910,13 @@ def expr(self, _p: int = 0): localctx = ExprParser.PortFieldSumContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 35 + self.state = 37 self.match(ExprParser.T__8) - self.state = 36 + self.state = 38 self.match(ExprParser.T__2) - self.state = 37 + self.state = 39 self.portFieldExpr() - self.state = 38 + self.state = 40 self.match(ExprParser.T__3) pass @@ -1822,21 +1924,21 @@ def expr(self, _p: int = 0): localctx = ExprParser.TimeSumContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 40 + self.state = 42 self.match(ExprParser.T__7) - self.state = 41 + self.state = 43 self.match(ExprParser.T__2) - self.state = 42 + self.state = 44 localctx.from_ = self.shift() - self.state = 43 + self.state = 45 self.match(ExprParser.T__9) - self.state = 44 + self.state = 46 localctx.to = self.shift() - self.state = 45 + self.state = 47 self.match(ExprParser.T__10) - self.state = 46 + self.state = 48 self.expr(0) - self.state = 47 + self.state = 49 self.match(ExprParser.T__3) pass @@ -1844,13 +1946,18 @@ def expr(self, _p: int = 0): localctx = ExprParser.FunctionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 49 - self.match(ExprParser.IDENTIFIER) - self.state = 50 - self.match(ExprParser.T__2) self.state = 51 - self.expr(0) + self.match(ExprParser.IDENTIFIER) self.state = 52 + self.match(ExprParser.T__2) + self.state = 54 + self._errHandler.sync(self) + _la = self._input.LA(1) + if ((_la) & ~0x3F) == 0 and ((1 << _la) & 82700) != 0: + self.state = 53 + self.argList() + + self.state = 56 self.match(ExprParser.T__3) pass @@ -1858,13 +1965,13 @@ def expr(self, _p: int = 0): localctx = ExprParser.TimeShiftContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 54 + self.state = 57 self.match(ExprParser.IDENTIFIER) - self.state = 55 + self.state = 58 self.match(ExprParser.T__11) - self.state = 56 + self.state = 59 self.shift() - self.state = 57 + self.state = 60 self.match(ExprParser.T__12) pass @@ -1872,13 +1979,13 @@ def expr(self, _p: int = 0): localctx = ExprParser.TimeIndexContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 59 + self.state = 62 self.match(ExprParser.IDENTIFIER) - self.state = 60 + self.state = 63 self.match(ExprParser.T__11) - self.state = 61 + self.state = 64 self.expr(0) - self.state = 62 + self.state = 65 self.match(ExprParser.T__12) pass @@ -1886,17 +1993,17 @@ def expr(self, _p: int = 0): localctx = ExprParser.TimeShiftExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 64 + self.state = 67 self.match(ExprParser.T__2) - self.state = 65 + self.state = 68 self.expr(0) - self.state = 66 + self.state = 69 self.match(ExprParser.T__3) - self.state = 67 + self.state = 70 self.match(ExprParser.T__11) - self.state = 68 + self.state = 71 self.shift() - self.state = 69 + self.state = 72 self.match(ExprParser.T__12) pass @@ -1904,32 +2011,32 @@ def expr(self, _p: int = 0): localctx = ExprParser.TimeIndexExprContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 71 + self.state = 74 self.match(ExprParser.T__2) - self.state = 72 + self.state = 75 self.expr(0) - self.state = 73 + self.state = 76 self.match(ExprParser.T__3) - self.state = 74 + self.state = 77 self.match(ExprParser.T__11) - self.state = 75 + self.state = 78 self.expr(0) - self.state = 76 + self.state = 79 self.match(ExprParser.T__12) pass self._ctx.stop = self._input.LT(-1) - self.state = 91 + self.state = 94 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 2, self._ctx) + _alt = self._interp.adaptivePredict(self._input, 3, self._ctx) while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: if _alt == 1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 89 + self.state = 92 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 1, self._ctx) + la_ = self._interp.adaptivePredict(self._input, 2, self._ctx) if la_ == 1: localctx = ExprParser.MuldivContext( self, ExprParser.ExprContext(self, _parentctx, _parentState) @@ -1937,14 +2044,14 @@ def expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_expr ) - self.state = 80 + self.state = 83 if not self.precpred(self._ctx, 11): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 11)" ) - self.state = 81 + self.state = 84 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 5 or _la == 6): @@ -1952,7 +2059,7 @@ def expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 82 + self.state = 85 self.expr(12) pass @@ -1963,14 +2070,14 @@ def expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_expr ) - self.state = 83 + self.state = 86 if not self.precpred(self._ctx, 10): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 10)" ) - self.state = 84 + self.state = 87 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -1978,7 +2085,7 @@ def expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 85 + self.state = 88 self.expr(11) pass @@ -1989,22 +2096,22 @@ def expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_expr ) - self.state = 86 + self.state = 89 if not self.precpred(self._ctx, 9): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 9)" ) - self.state = 87 + self.state = 90 self.match(ExprParser.COMPARISON) - self.state = 88 + self.state = 91 self.expr(10) pass - self.state = 93 + self.state = 96 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 2, self._ctx) + _alt = self._interp.adaptivePredict(self._input, 3, self._ctx) except RecognitionException as re: localctx.exception = re @@ -2014,6 +2121,58 @@ def expr(self, _p: int = 0): self.unrollRecursionContexts(_parentctx) return localctx + class ArgListContext(ParserRuleContext): + __slots__ = "parser" + + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i: int = None): + if i is None: + return self.getTypedRuleContexts(ExprParser.ExprContext) + else: + return self.getTypedRuleContext(ExprParser.ExprContext, i) + + def getRuleIndex(self): + return ExprParser.RULE_argList + + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, "visitArgList"): + return visitor.visitArgList(self) + else: + return visitor.visitChildren(self) + + def argList(self): + localctx = ExprParser.ArgListContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_argList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 97 + self.expr(0) + self.state = 102 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la == 11: + self.state = 98 + self.match(ExprParser.T__10) + self.state = 99 + self.expr(0) + self.state = 104 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + class AtomContext(ParserRuleContext): __slots__ = "parser" @@ -2063,21 +2222,21 @@ def accept(self, visitor: ParseTreeVisitor): def atom(self): localctx = ExprParser.AtomContext(self, self._ctx, self.state) - self.enterRule(localctx, 6, self.RULE_atom) + self.enterRule(localctx, 8, self.RULE_atom) try: - self.state = 96 + self.state = 107 self._errHandler.sync(self) token = self._input.LA(1) if token in [14]: localctx = ExprParser.NumberContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 94 + self.state = 105 self.match(ExprParser.NUMBER) pass elif token in [16]: localctx = ExprParser.IdentifierContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 95 + self.state = 106 self.match(ExprParser.IDENTIFIER) pass else: @@ -2117,17 +2276,17 @@ def accept(self, visitor: ParseTreeVisitor): def shift(self): localctx = ExprParser.ShiftContext(self, self._ctx, self.state) - self.enterRule(localctx, 8, self.RULE_shift) + self.enterRule(localctx, 10, self.RULE_shift) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 98 + self.state = 109 self.match(ExprParser.TIME) - self.state = 100 + self.state = 111 self._errHandler.sync(self) _la = self._input.LA(1) if _la == 2 or _la == 7: - self.state = 99 + self.state = 110 self.shift_expr(0) except RecognitionException as re: @@ -2232,20 +2391,20 @@ def shift_expr(self, _p: int = 0): _parentState = self.state localctx = ExprParser.Shift_exprContext(self, self._ctx, _parentState) _prevctx = localctx - _startState = 10 - self.enterRecursionRule(localctx, 10, self.RULE_shift_expr, _p) + _startState = 12 + self.enterRecursionRule(localctx, 12, self.RULE_shift_expr, _p) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 110 + self.state = 121 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 5, self._ctx) + la_ = self._interp.adaptivePredict(self._input, 7, self._ctx) if la_ == 1: localctx = ExprParser.SignedAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 103 + self.state = 114 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -2253,7 +2412,7 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 104 + self.state = 115 self.atom() pass @@ -2261,7 +2420,7 @@ def shift_expr(self, _p: int = 0): localctx = ExprParser.SignedExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 105 + self.state = 116 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -2269,26 +2428,26 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 106 + self.state = 117 self.match(ExprParser.T__2) - self.state = 107 + self.state = 118 self.expr(0) - self.state = 108 + self.state = 119 self.match(ExprParser.T__3) pass self._ctx.stop = self._input.LT(-1) - self.state = 120 + self.state = 131 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 7, self._ctx) + _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: if _alt == 1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 118 + self.state = 129 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 6, self._ctx) + la_ = self._interp.adaptivePredict(self._input, 8, self._ctx) if la_ == 1: localctx = ExprParser.ShiftMuldivContext( self, @@ -2299,14 +2458,14 @@ def shift_expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_shift_expr ) - self.state = 112 + self.state = 123 if not self.precpred(self._ctx, 4): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 4)" ) - self.state = 113 + self.state = 124 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 5 or _la == 6): @@ -2314,7 +2473,7 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 114 + self.state = 125 self.right_expr(0) pass @@ -2328,14 +2487,14 @@ def shift_expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_shift_expr ) - self.state = 115 + self.state = 126 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 3)" ) - self.state = 116 + self.state = 127 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 2 or _la == 7): @@ -2343,13 +2502,13 @@ def shift_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 117 + self.state = 128 self.right_expr(0) pass - self.state = 122 + self.state = 133 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 7, self._ctx) + _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) except RecognitionException as re: localctx.exception = re @@ -2431,12 +2590,12 @@ def right_expr(self, _p: int = 0): _parentState = self.state localctx = ExprParser.Right_exprContext(self, self._ctx, _parentState) _prevctx = localctx - _startState = 12 - self.enterRecursionRule(localctx, 12, self.RULE_right_expr, _p) + _startState = 14 + self.enterRecursionRule(localctx, 14, self.RULE_right_expr, _p) self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 129 + self.state = 140 self._errHandler.sync(self) token = self._input.LA(1) if token in [3]: @@ -2444,27 +2603,27 @@ def right_expr(self, _p: int = 0): self._ctx = localctx _prevctx = localctx - self.state = 124 + self.state = 135 self.match(ExprParser.T__2) - self.state = 125 + self.state = 136 self.expr(0) - self.state = 126 + self.state = 137 self.match(ExprParser.T__3) pass elif token in [14, 16]: localctx = ExprParser.RightAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 128 + self.state = 139 self.atom() pass else: raise NoViableAltException(self) self._ctx.stop = self._input.LT(-1) - self.state = 136 + self.state = 147 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) + _alt = self._interp.adaptivePredict(self._input, 11, self._ctx) while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: if _alt == 1: if self._parseListeners is not None: @@ -2477,14 +2636,14 @@ def right_expr(self, _p: int = 0): self.pushNewRecursionContext( localctx, _startState, self.RULE_right_expr ) - self.state = 131 + self.state = 142 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException( self, "self.precpred(self._ctx, 3)" ) - self.state = 132 + self.state = 143 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not (_la == 5 or _la == 6): @@ -2492,11 +2651,11 @@ def right_expr(self, _p: int = 0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 133 + self.state = 144 self.right_expr(4) - self.state = 138 + self.state = 149 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) + _alt = self._interp.adaptivePredict(self._input, 11, self._ctx) except RecognitionException as re: localctx.exception = re @@ -2510,8 +2669,8 @@ def sempred(self, localctx: RuleContext, ruleIndex: int, predIndex: int): if self._predicates == None: self._predicates = dict() self._predicates[2] = self.expr_sempred - self._predicates[5] = self.shift_expr_sempred - self._predicates[6] = self.right_expr_sempred + self._predicates[6] = self.shift_expr_sempred + self._predicates[7] = self.right_expr_sempred pred = self._predicates.get(ruleIndex, None) if pred is None: raise Exception("No predicate with index:" + str(ruleIndex)) diff --git a/src/gems/expression/parsing/antlr/ExprVisitor.py b/src/gems/expression/parsing/antlr/ExprVisitor.py index a98fbbaf..3a4cf8d6 100644 --- a/src/gems/expression/parsing/antlr/ExprVisitor.py +++ b/src/gems/expression/parsing/antlr/ExprVisitor.py @@ -78,6 +78,10 @@ def visitTimeShift(self, ctx: ExprParser.TimeShiftContext): def visitFunction(self, ctx: ExprParser.FunctionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#argList. + def visitArgList(self, ctx: ExprParser.ArgListContext): + return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#number. def visitNumber(self, ctx: ExprParser.NumberContext): return self.visitChildren(ctx) diff --git a/src/gems/expression/parsing/parse_expression.py b/src/gems/expression/parsing/parse_expression.py index a0bc80e6..d090c872 100644 --- a/src/gems/expression/parsing/parse_expression.py +++ b/src/gems/expression/parsing/parse_expression.py @@ -22,6 +22,8 @@ ComparisonNode, PortFieldAggregatorNode, PortFieldNode, + maximum, + minimum, ) from gems.expression.parsing.antlr.ExprLexer import ExprLexer from gems.expression.parsing.antlr.ExprParser import ExprParser @@ -160,11 +162,23 @@ def visitAllTimeSum(self, ctx: ExprParser.AllTimeSumContext) -> ExpressionNode: # Visit a parse tree produced by ExprParser#function. def visitFunction(self, ctx: ExprParser.FunctionContext) -> ExpressionNode: function_name: str = ctx.IDENTIFIER().getText() # type: ignore - operand: ExpressionNode = ctx.expr().accept(self) # type: ignore - fn = _FUNCTIONS.get(function_name, None) - if fn is None: - raise ValueError(f"Encountered invalid function name {function_name}") - return fn(operand) + arg_list = ctx.argList() # type: ignore + args: list[ExpressionNode] = ( + [expr.accept(self) for expr in arg_list.expr()] # type: ignore + if arg_list is not None + else [] + ) + if function_name in _FUNCTIONS: + if len(args) != 1: + raise ValueError( + f"Function {function_name} requires exactly 1 argument, got {len(args)}" + ) + return _FUNCTIONS[function_name](args[0]) + if function_name == "max": + return maximum(*args) + if function_name == "min": + return minimum(*args) + raise ValueError(f"Encountered invalid function name {function_name}") # Visit a parse tree produced by ExprParser#shift. def visitShift(self, ctx: ExprParser.ShiftContext) -> ExpressionNode: @@ -235,6 +249,8 @@ def visitRightAtom(self, ctx: ExprParser.RightAtomContext) -> ExpressionNode: _FUNCTIONS = { "expec": ExpressionNode.expec, + "floor": ExpressionNode.floor, + "ceil": ExpressionNode.ceil, } diff --git a/src/gems/expression/print.py b/src/gems/expression/print.py index 50308d87..d819185c 100644 --- a/src/gems/expression/print.py +++ b/src/gems/expression/print.py @@ -15,9 +15,13 @@ from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, ExpressionNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -130,6 +134,18 @@ def port_field(self, node: PortFieldNode) -> str: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> str: return f"({visit(node.operand, self)}.{node.aggregator})" + def floor(self, node: FloorNode) -> str: + return f"floor({visit(node.operand, self)})" + + def ceil(self, node: CeilNode) -> str: + return f"ceil({visit(node.operand, self)})" + + def maximum(self, node: MaxNode) -> str: + return "max(" + ", ".join(visit(op, self) for op in node.operands) + ")" + + def minimum(self, node: MinNode) -> str: + return "min(" + ", ".join(visit(op, self) for op in node.operands) + ")" + def print_expr(expression: ExpressionNode) -> str: return visit(expression, PrinterVisitor()) diff --git a/src/gems/expression/visitor.py b/src/gems/expression/visitor.py index 7180b6ab..dc5ad42c 100644 --- a/src/gems/expression/visitor.py +++ b/src/gems/expression/visitor.py @@ -13,6 +13,7 @@ """ Defines abstract base class for visitors of expressions. """ + import typing from abc import ABC, abstractmethod from typing import Generic, Protocol, TypeVar @@ -20,12 +21,16 @@ from gems.expression.expression import ( AdditionNode, AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, DivisionNode, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, MultiplicationNode, NegationNode, ParameterNode, @@ -128,6 +133,22 @@ def port_field(self, node: PortFieldNode) -> T: def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: ... + @abstractmethod + def floor(self, node: FloorNode) -> T: + ... + + @abstractmethod + def ceil(self, node: CeilNode) -> T: + ... + + @abstractmethod + def maximum(self, node: MaxNode) -> T: + ... + + @abstractmethod + def minimum(self, node: MinNode) -> T: + ... + def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: """ @@ -171,6 +192,14 @@ def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: return visitor.port_field(root) elif isinstance(root, PortFieldAggregatorNode): return visitor.port_field_aggregator(root) + elif isinstance(root, FloorNode): + return visitor.floor(root) + elif isinstance(root, CeilNode): + return visitor.ceil(root) + elif isinstance(root, MaxNode): + return visitor.maximum(root) + elif isinstance(root, MinNode): + return visitor.minimum(root) raise ValueError(f"Unknown expression node type {root.__class__}") diff --git a/src/gems/model/common.py b/src/gems/model/common.py index e295c1aa..b4342735 100644 --- a/src/gems/model/common.py +++ b/src/gems/model/common.py @@ -13,6 +13,7 @@ """ Module for common classes used in models. """ + from enum import Enum diff --git a/src/gems/model/model.py b/src/gems/model/model.py index a49e934d..bd5c8b5c 100644 --- a/src/gems/model/model.py +++ b/src/gems/model/model.py @@ -15,6 +15,7 @@ A model allows to define the behaviour for components, by defining parameters, variables, and equations. """ + import itertools from dataclasses import dataclass, field, replace from typing import Any, Dict, Iterable, Optional @@ -179,9 +180,9 @@ def model( return Model( id=id, constraints={c.name: c for c in constraints} if constraints else {}, - binding_constraints={c.name: c for c in binding_constraints} - if binding_constraints - else {}, + binding_constraints=( + {c.name: c for c in binding_constraints} if binding_constraints else {} + ), parameters={p.name: p for p in parameters} if parameters else {}, variables={v.name: v for v in variables} if variables else {}, objective_contributions=objective_contributions, @@ -189,8 +190,10 @@ def model( objective_operational_contribution=objective_operational_contribution, inter_block_dyn=inter_block_dyn, ports=existing_port_names, - port_fields_definitions={d.port_field: d for d in port_fields_definitions} - if port_fields_definitions - else {}, + port_fields_definitions=( + {d.port_field: d for d in port_fields_definitions} + if port_fields_definitions + else {} + ), extra_outputs=extra_outputs, ) diff --git a/src/gems/model/port.py b/src/gems/model/port.py index cb07319b..a13bce35 100644 --- a/src/gems/model/port.py +++ b/src/gems/model/port.py @@ -28,8 +28,12 @@ from gems.expression.expression import ( AllTimeSumNode, BinaryOperatorNode, + CeilNode, ComponentParameterNode, ComponentVariableNode, + FloorNode, + MaxNode, + MinNode, PortFieldAggregatorNode, PortFieldNode, ProblemParameterNode, @@ -165,6 +169,20 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> None: def port_field(self, node: PortFieldNode) -> None: raise ValueError("Port definition cannot reference another port field.") + def floor(self, node: FloorNode) -> None: + visit(node.operand, self) + + def ceil(self, node: CeilNode) -> None: + visit(node.operand, self) + + def maximum(self, node: MaxNode) -> None: + for op in node.operands: + visit(op, self) + + def minimum(self, node: MinNode) -> None: + for op in node.operands: + visit(op, self) + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: raise ValueError("Port definition cannot contain port field aggregation.") diff --git a/src/gems/simulation/linear_expression.py b/src/gems/simulation/linear_expression.py index 1f882e64..a0584c07 100644 --- a/src/gems/simulation/linear_expression.py +++ b/src/gems/simulation/linear_expression.py @@ -14,6 +14,7 @@ Specific modelling for "instantiated" linear expressions, with only variables and literal coefficients. """ + import dataclasses from dataclasses import dataclass from typing import Callable, Dict, List, Optional, TypeVar, Union @@ -118,7 +119,6 @@ def apply(self, other: "TimeExpansion") -> "TimeExpansion": @dataclass(frozen=True) class TermKey: - """ Utility class to provide key for a term that contains all term information except coefficient """ diff --git a/src/gems/simulation/linearize.py b/src/gems/simulation/linearize.py index 7e8dbf15..0d3c9e71 100644 --- a/src/gems/simulation/linearize.py +++ b/src/gems/simulation/linearize.py @@ -23,12 +23,16 @@ ) from gems.expression.expression import ( AllTimeSumNode, + CeilNode, ComparisonNode, ComponentParameterNode, ComponentVariableNode, CurrentScenarioIndex, ExpressionNode, + FloorNode, LiteralNode, + MaxNode, + MinNode, NoScenarioIndex, NoTimeIndex, OneScenarioIndex, @@ -200,6 +204,18 @@ def _get_scenario(self, scenario_index: ScenarioIndex) -> Optional[int]: def literal(self, node: LiteralNode) -> LinearExpressionData: return LinearExpressionData([], node.value) + def floor(self, node: FloorNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a floor operator.") + + def ceil(self, node: CeilNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a ceil operator.") + + def maximum(self, node: MaxNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a max operator.") + + def minimum(self, node: MinNode) -> LinearExpressionData: + raise ValueError("Linear expression cannot contain a min operator.") + def comparison(self, node: ComparisonNode) -> LinearExpressionData: raise ValueError("Linear expression cannot contain a comparison operator.") diff --git a/src/gems/simulation/output_values.py b/src/gems/simulation/output_values.py index a0bbc5f7..e8d4eed7 100644 --- a/src/gems/simulation/output_values.py +++ b/src/gems/simulation/output_values.py @@ -7,6 +7,7 @@ """ Utility classes to obtain solver results. """ + import math from dataclasses import dataclass, field from typing import Any, Dict, Mapping, Optional, Set, TypeVar diff --git a/src/gems/study/network.py b/src/gems/study/network.py index 0ac27ccc..53a53687 100644 --- a/src/gems/study/network.py +++ b/src/gems/study/network.py @@ -14,6 +14,7 @@ The network module defines the data model for an instance of network, including nodes, links, and components (model instantations). """ + import itertools from dataclasses import dataclass, field, replace from typing import Any, Dict, Iterable, List, cast diff --git a/src/gems/utils.py b/src/gems/utils.py index 23c1ebaa..7964fd24 100644 --- a/src/gems/utils.py +++ b/src/gems/utils.py @@ -13,6 +13,7 @@ """ Module for technical utilities. """ + import json import pathlib from typing import Any, Callable, Dict, Optional, TypeVar diff --git a/tests/e2e/functional/libs/standard.py b/tests/e2e/functional/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/e2e/functional/libs/standard.py +++ b/tests/e2e/functional/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure diff --git a/tests/e2e/functional/test_libs_python_system_python.py b/tests/e2e/functional/test_libs_python_system_python.py index aea1abff..01a7bebe 100644 --- a/tests/e2e/functional/test_libs_python_system_python.py +++ b/tests/e2e/functional/test_libs_python_system_python.py @@ -33,6 +33,7 @@ - Description: Short-term storage behavior over different horizons and efficiencies. - Names: `test_short_test_horizon_10`, `test_short_test_horizon_5`. """ + import pandas as pd import pytest diff --git a/tests/e2e/integration/libs/standard.py b/tests/e2e/integration/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/e2e/integration/libs/standard.py +++ b/tests/e2e/integration/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure diff --git a/tests/e2e/models/poc-various-models/libs/standard.py b/tests/e2e/models/poc-various-models/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/e2e/models/poc-various-models/libs/standard.py +++ b/tests/e2e/models/poc-various-models/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure diff --git a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py index 0d8d7505..1663ff40 100644 --- a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py +++ b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs.py @@ -129,9 +129,9 @@ def test_electrolyzer_n_inputs_1() -> None: print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -218,9 +218,9 @@ def test_electrolyzer_n_inputs_2() -> None: print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -313,9 +313,9 @@ def test_electrolyzer_n_inputs_3() -> None: ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) @@ -401,9 +401,9 @@ def test_electrolyzer_n_inputs_4() -> None: ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) diff --git a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py index 2e007571..78bfc5a0 100644 --- a/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py +++ b/tests/e2e/models/poc-various-models/test_electrolyzer_n_inputs_yaml.py @@ -135,9 +135,9 @@ def test_electrolyzer_n_inputs_1( print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -235,9 +235,9 @@ def test_electrolyzer_n_inputs_2( print(ep2_gen) print(gp_gen) - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 42) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 42) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1990) @@ -342,9 +342,9 @@ def test_electrolyzer_n_inputs_3( ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) @@ -443,9 +443,9 @@ def test_electrolyzer_n_inputs_4( ep2_gen = output.component("ep2").var("generation").value gp_gen = output.component("gp").var("generation").value - assert math.isclose(ep1_gen, 70) # type:ignore - assert math.isclose(ep2_gen, 30) # type:ignore - assert math.isclose(gp_gen, 30) # type:ignore + assert math.isclose(ep1_gen, 70) # type: ignore + assert math.isclose(ep2_gen, 30) # type: ignore + assert math.isclose(gp_gen, 30) # type: ignore assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 1750) diff --git a/tests/e2e/models/poc-various-models/test_quota_co2.py b/tests/e2e/models/poc-various-models/test_quota_co2.py index 2bd504ee..ee3bc445 100644 --- a/tests/e2e/models/poc-various-models/test_quota_co2.py +++ b/tests/e2e/models/poc-various-models/test_quota_co2.py @@ -85,6 +85,6 @@ def test_quota_co2() -> None: assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 5500) - assert math.isclose(oil1_p, 50) # type:ignore - assert math.isclose(coal1_p, 50) # type:ignore - assert math.isclose(l12_flow, -50) # type:ignore + assert math.isclose(oil1_p, 50) # type: ignore + assert math.isclose(coal1_p, 50) # type: ignore + assert math.isclose(l12_flow, -50) # type: ignore diff --git a/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py b/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py index e0163cc7..812a75cc 100644 --- a/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py +++ b/tests/e2e/models/poc-various-models/test_quota_co2_yaml.py @@ -91,6 +91,6 @@ def test_quota_co2( assert status == problem.solver.OPTIMAL assert math.isclose(problem.solver.Objective().Value(), 5500) - assert math.isclose(oil1_p, 50) # type:ignore - assert math.isclose(coal1_p, 50) # type:ignore - assert math.isclose(l12_flow, -50) # type:ignore + assert math.isclose(oil1_p, 50) # type: ignore + assert math.isclose(coal1_p, 50) # type: ignore + assert math.isclose(l12_flow, -50) # type: ignore diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index d6558d7f..7287bed3 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -15,7 +15,7 @@ from gems.expression import ExpressionNode, literal, param, print_expr, var from gems.expression.equality import expressions_equal -from gems.expression.expression import port_field +from gems.expression.expression import maximum, minimum, port_field from gems.expression.parsing.parse_expression import ( AntaresParseException, ModelIdentifiers, @@ -119,6 +119,54 @@ "expec(sum(cost * generation))", (param("cost") * var("generation")).time_sum().expec(), ), + ( + {}, + {"p"}, + "floor(p)", + param("p").floor(), + ), + ( + {}, + {"p"}, + "ceil(p)", + param("p").ceil(), + ), + ( + {}, + {"a", "b"}, + "max(a, b)", + maximum(param("a"), param("b")), + ), + ( + {}, + {"a", "b"}, + "min(a, b)", + minimum(param("a"), param("b")), + ), + ( + {}, + {"p", "q"}, + "ceil(p/q)", + (param("p") / param("q")).ceil(), + ), + ( + {}, + {"p", "q"}, + "max(0, ceil(p/q))", + maximum(literal(0), (param("p") / param("q")).ceil()), + ), + ( + {}, + {"a", "b", "c"}, + "max(a, b, c)", + maximum(param("a"), param("b"), param("c")), + ), + ( + {}, + {"a", "b", "c"}, + "min(a, b, c)", + minimum(param("a"), param("b"), param("c")), + ), ], ) def test_parsing_visitor( diff --git a/tests/unittests/expressions/visitor/test_degree.py b/tests/unittests/expressions/visitor/test_degree.py index 72996d64..e2b3b492 100644 --- a/tests/unittests/expressions/visitor/test_degree.py +++ b/tests/unittests/expressions/visitor/test_degree.py @@ -10,9 +10,20 @@ # # This file is part of the Antares project. +import math + import pytest -from gems.expression import ExpressionDegreeVisitor, LiteralNode, param, var, visit +from gems.expression import ( + ExpressionDegreeVisitor, + LiteralNode, + maximum, + minimum, + param, + var, + visit, +) +from gems.expression.expression import CeilNode, FloorNode def test_degree() -> None: @@ -26,6 +37,33 @@ def test_degree() -> None: assert visit(expr, ExpressionDegreeVisitor()) == 2 +def test_floor_ceil_degree() -> None: + x = var("x") + p = param("p") + + assert visit(FloorNode(p), ExpressionDegreeVisitor()) == 0 + assert visit(CeilNode(p), ExpressionDegreeVisitor()) == 0 + assert visit(FloorNode(x), ExpressionDegreeVisitor()) == math.inf + assert visit(CeilNode(x), ExpressionDegreeVisitor()) == math.inf + + +def test_max_min_degree() -> None: + x = var("x") + p = param("p") + q = param("q") + + assert visit(maximum(p, q), ExpressionDegreeVisitor()) == 0 + assert visit(minimum(p, q), ExpressionDegreeVisitor()) == 0 + assert visit(maximum(x, p), ExpressionDegreeVisitor()) == math.inf + assert visit(minimum(p, x), ExpressionDegreeVisitor()) == math.inf + assert visit(maximum(x, x), ExpressionDegreeVisitor()) == math.inf + # variadic (3+ operands) + assert visit(maximum(p, q, param("r")), ExpressionDegreeVisitor()) == 0 + assert visit(minimum(p, q, param("r")), ExpressionDegreeVisitor()) == 0 + assert visit(maximum(p, q, x), ExpressionDegreeVisitor()) == math.inf + assert visit(minimum(p, x, q), ExpressionDegreeVisitor()) == math.inf + + @pytest.mark.xfail(reason="Degree simplification not implemented") def test_degree_computation_should_take_into_account_simplifications() -> None: x = var("x") diff --git a/tests/unittests/expressions/visitor/test_equality.py b/tests/unittests/expressions/visitor/test_equality.py index bcdec8ff..b2efd0fb 100644 --- a/tests/unittests/expressions/visitor/test_equality.py +++ b/tests/unittests/expressions/visitor/test_equality.py @@ -14,6 +14,7 @@ from gems.expression import ExpressionNode, copy_expression, literal, param, var from gems.expression.equality import expressions_equal +from gems.expression.expression import maximum, minimum @pytest.mark.parametrize( @@ -30,6 +31,10 @@ var("x").time_sum(), var("x") + 5 <= 2, var("x").expec(), + var("x").floor(), + var("x").ceil(), + maximum(var("x"), param("p")), + minimum(var("x"), param("p")), ], ) def test_equals(expr: ExpressionNode) -> None: @@ -52,6 +57,14 @@ def test_equals(expr: ExpressionNode) -> None: var("x").time_sum(1, 10), ), (var("x").expec(), var("y").expec()), + # floor / ceil + (var("x").floor(), var("y").floor()), + (var("x").ceil(), var("y").ceil()), + (var("x").floor(), var("x").ceil()), # different node type + # max / min + (maximum(var("x"), param("p")), maximum(var("y"), param("p"))), + (minimum(var("x"), param("p")), minimum(var("x"), param("q"))), + (maximum(var("x"), param("p")), minimum(var("x"), param("p"))), # Max vs Min ], ) def test_not_equals(lhs: ExpressionNode, rhs: ExpressionNode) -> None: diff --git a/tests/unittests/expressions/visitor/test_evaluation.py b/tests/unittests/expressions/visitor/test_evaluation.py index 62e25183..efd110ff 100644 --- a/tests/unittests/expressions/visitor/test_evaluation.py +++ b/tests/unittests/expressions/visitor/test_evaluation.py @@ -111,3 +111,29 @@ def test_sum_expressions() -> None: assert expressions_equal( sum_expressions([literal(1), var("x"), param("p")]), 1 + (var("x") + param("p")) ) + + +def test_floor_ceil_max_min() -> None: + from gems.expression.expression import maximum, minimum + + context = EvaluationContext(parameters={"p": 2.7, "q": 1.3}) + + assert visit(param("p").floor(), EvaluationVisitor(context)) == 2.0 + assert visit(param("p").ceil(), EvaluationVisitor(context)) == 3.0 + assert visit( + maximum(param("p"), param("q")), EvaluationVisitor(context) + ) == pytest.approx(2.7) + assert visit( + minimum(param("p"), param("q")), EvaluationVisitor(context) + ) == pytest.approx(1.3) + assert visit( + maximum(literal(0), param("q")), EvaluationVisitor(context) + ) == pytest.approx(1.3) + assert visit(maximum(literal(0), -param("p")), EvaluationVisitor(context)) == 0.0 + # variadic (3+ operands) + assert visit( + maximum(param("p"), param("q"), literal(5.0)), EvaluationVisitor(context) + ) == pytest.approx(5.0) + assert visit( + minimum(param("p"), param("q"), literal(5.0)), EvaluationVisitor(context) + ) == pytest.approx(1.3) diff --git a/tests/unittests/expressions/visitor/test_printer.py b/tests/unittests/expressions/visitor/test_printer.py index a41250af..34510760 100644 --- a/tests/unittests/expressions/visitor/test_printer.py +++ b/tests/unittests/expressions/visitor/test_printer.py @@ -19,3 +19,23 @@ def test_comparison() -> None: expr: ExpressionNode = (5 * x + 3) >= p - 2 assert visit(expr, PrinterVisitor()) == "((5.0 * x) + 3.0) >= (p - 2.0)" + + +def test_floor_ceil_max_min_printer() -> None: + from gems.expression.expression import maximum, minimum + + p = param("p") + q = param("q") + + assert visit(p.floor(), PrinterVisitor()) == "floor(p)" + assert visit(p.ceil(), PrinterVisitor()) == "ceil(p)" + assert visit(maximum(p, q), PrinterVisitor()) == "max(p, q)" + assert visit(minimum(p, q), PrinterVisitor()) == "min(p, q)" + assert visit((p / q).ceil(), PrinterVisitor()) == "ceil((p / q))" + assert ( + visit(maximum(param("a"), (p / q).ceil()), PrinterVisitor()) + == "max(a, ceil((p / q)))" + ) + # variadic (3+ operands) + assert visit(maximum(p, q, param("r")), PrinterVisitor()) == "max(p, q, r)" + assert visit(minimum(p, q, param("r")), PrinterVisitor()) == "min(p, q, r)" diff --git a/tests/unittests/lib_parsing/test_objective_contribution.py b/tests/unittests/lib_parsing/test_objective_contribution.py index 49908930..18b48991 100644 --- a/tests/unittests/lib_parsing/test_objective_contribution.py +++ b/tests/unittests/lib_parsing/test_objective_contribution.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MPL-2.0 """Tests for objective creation logic and coefficient accumulation.""" + from typing import Any, Dict, Optional from unittest.mock import Mock, patch @@ -97,7 +98,7 @@ def name(self) -> str: def _setup_mock_optimization_environment( - linear_expressions: Dict[Any, LinearExpression] + linear_expressions: Dict[Any, LinearExpression], ) -> Any: """Sets up a mock context and problem with real OR-Tools solver objects.""" diff --git a/tests/unittests/system/libs/standard.py b/tests/unittests/system/libs/standard.py index bd30e958..a8f6cd7f 100644 --- a/tests/unittests/system/libs/standard.py +++ b/tests/unittests/system/libs/standard.py @@ -13,6 +13,7 @@ """ The standard module contains the definition of standard models. """ + from gems.expression import literal, param, var from gems.expression.expression import port_field from gems.expression.indexing_structure import IndexingStructure