Skip to content
Closed
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ bin
outputs
build
src/gems.egg-info/
src/gemspy.egg-info/
1 change: 1 addition & 0 deletions grammar/Expr.g4
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ expr
| 'sum_connections' '(' portFieldExpr ')' # portFieldSum
| 'sum' '(' from=shift '..' to=shift ',' expr ')' # timeSum
| IDENTIFIER '(' expr ')' # function
| IDENTIFIER '(' expr ',' expr ')' # binaryFunction
| IDENTIFIER '[' shift ']' # timeShift
| IDENTIFIER '[' expr ']' # timeIndex
| '(' expr ')' '[' shift ']' # timeShiftExpr
Expand Down
6 changes: 6 additions & 0 deletions src/gems/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@
)
from .expression import (
AdditionNode,
CeilNode,
Comparator,
ComparisonNode,
DivisionNode,
ExpressionNode,
FloorNode,
LiteralNode,
MaxNode,
MinNode,
MultiplicationNode,
NegationNode,
ParameterNode,
VariableNode,
literal,
maximum,
minimum,
param,
sum_expressions,
var,
Expand Down
16 changes: 16 additions & 0 deletions src/gems/expression/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@

from .expression import (
AllTimeSumNode,
CeilNode,
ComparisonNode,
ComponentParameterNode,
ComponentVariableNode,
ExpressionNode,
FloorNode,
LiteralNode,
MaxNode,
MinNode,
ParameterNode,
PortFieldAggregatorNode,
PortFieldNode,
Expand Down Expand Up @@ -95,6 +99,18 @@ def port_field(self, node: PortFieldNode) -> ExpressionNode:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode:
return PortFieldAggregatorNode(visit(node.operand, self), node.aggregator)

def floor(self, node: FloorNode) -> ExpressionNode:
return FloorNode(visit(node.operand, self))

def ceil(self, node: CeilNode) -> ExpressionNode:
return CeilNode(visit(node.operand, self))

def maximum(self, node: MaxNode) -> ExpressionNode:
return MaxNode(visit(node.left, self), visit(node.right, self))

def minimum(self, node: MinNode) -> ExpressionNode:
return MinNode(visit(node.left, self), visit(node.right, self))


def copy_expression(expression: ExpressionNode) -> ExpressionNode:
return visit(expression, CopyVisitor())
66 changes: 45 additions & 21 deletions src/gems/expression/degree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@
#
# This file is part of the Antares project.

import math

import gems.expression.scenario_operator
from gems.expression.expression import (
AllTimeSumNode,
CeilNode,
ComponentParameterNode,
ComponentVariableNode,
FloorNode,
MaxNode,
MinNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
Expand All @@ -39,77 +45,95 @@
from .visitor import ExpressionVisitor, T, visit


class ExpressionDegreeVisitor(ExpressionVisitor[int]):
class ExpressionDegreeVisitor(ExpressionVisitor[int | float]):
"""
Computes degree of expression with respect to variables.
"""

def literal(self, node: LiteralNode) -> int:
def literal(self, node: LiteralNode) -> int | float:
return 0

def negation(self, node: NegationNode) -> int:
def negation(self, node: NegationNode) -> int | float:
return visit(node.operand, self)

# TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div
def addition(self, node: AdditionNode) -> int:
def addition(self, node: AdditionNode) -> int | float:
degrees = [visit(o, self) for o in node.operands]
return max(degrees)

def multiplication(self, node: MultiplicationNode) -> int:
def multiplication(self, node: MultiplicationNode) -> int | float:
return visit(node.left, self) + visit(node.right, self)

def division(self, node: DivisionNode) -> int:
def division(self, node: DivisionNode) -> int | float:
right_degree = visit(node.right, self)
if right_degree != 0:
raise ValueError("Degree computation not implemented for divisions.")
return visit(node.left, self)

def comparison(self, node: ComparisonNode) -> int:
def comparison(self, node: ComparisonNode) -> int | float:
return max(visit(node.left, self), visit(node.right, self))

def variable(self, node: VariableNode) -> int:
def variable(self, node: VariableNode) -> int | float:
return 1

def parameter(self, node: ParameterNode) -> int:
def parameter(self, node: ParameterNode) -> int | float:
return 0

def comp_variable(self, node: ComponentVariableNode) -> int:
def comp_variable(self, node: ComponentVariableNode) -> int | float:
return 1

def comp_parameter(self, node: ComponentParameterNode) -> int:
def comp_parameter(self, node: ComponentParameterNode) -> int | float:
return 0

def pb_variable(self, node: ProblemVariableNode) -> int:
def pb_variable(self, node: ProblemVariableNode) -> int | float:
return 1

def pb_parameter(self, node: ProblemParameterNode) -> int:
def pb_parameter(self, node: ProblemParameterNode) -> int | float:
return 0

def time_shift(self, node: TimeShiftNode) -> int:
def time_shift(self, node: TimeShiftNode) -> int | float:
return visit(node.operand, self)

def time_eval(self, node: TimeEvalNode) -> int:
def time_eval(self, node: TimeEvalNode) -> int | float:
return visit(node.operand, self)

def time_sum(self, node: TimeSumNode) -> int:
def time_sum(self, node: TimeSumNode) -> int | float:
return visit(node.operand, self)

def all_time_sum(self, node: AllTimeSumNode) -> int:
def all_time_sum(self, node: AllTimeSumNode) -> int | float:
return visit(node.operand, self)

def scenario_operator(self, node: ScenarioOperatorNode) -> int:
def scenario_operator(self, node: ScenarioOperatorNode) -> int | float:
scenario_operator_cls = getattr(gems.expression.scenario_operator, node.name)
# TODO: Carefully check if this formula is correct
return scenario_operator_cls.degree() * visit(node.operand, self)

def port_field(self, node: PortFieldNode) -> int:
def port_field(self, node: PortFieldNode) -> int | float:
return 1

def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int | float:
return visit(node.operand, self)

def floor(self, node: FloorNode) -> int | float:
d = visit(node.operand, self)
return 0 if d == 0 else math.inf

def ceil(self, node: CeilNode) -> int | float:
d = visit(node.operand, self)
return 0 if d == 0 else math.inf

def maximum(self, node: MaxNode) -> int | float:
d_l = visit(node.left, self)
d_r = visit(node.right, self)
return 0 if d_l == 0 and d_r == 0 else math.inf

def minimum(self, node: MinNode) -> int | float:
d_l = visit(node.left, self)
d_r = visit(node.right, self)
return 0 if d_l == 0 and d_r == 0 else math.inf


def compute_degree(expression: ExpressionNode) -> int:
def compute_degree(expression: ExpressionNode) -> int | float:
return visit(expression, ExpressionDegreeVisitor())


Expand Down
24 changes: 24 additions & 0 deletions src/gems/expression/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@
from gems.expression.expression import (
AllTimeSumNode,
BinaryOperatorNode,
CeilNode,
ComponentParameterNode,
ComponentVariableNode,
FloorNode,
MaxNode,
MinNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
Expand Down Expand Up @@ -111,6 +115,14 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool:
right, PortFieldAggregatorNode
):
return self.port_field_aggregator(left, right)
if isinstance(left, FloorNode) and isinstance(right, FloorNode):
return self.floor(left, right)
if isinstance(left, CeilNode) and isinstance(right, CeilNode):
return self.ceil(left, right)
if isinstance(left, MaxNode) and isinstance(right, MaxNode):
return self.maximum(left, right)
if isinstance(left, MinNode) and isinstance(right, MinNode):
return self.minimum(left, right)
raise NotImplementedError(f"Equality not implemented for {left.__class__}")

def literal(self, left: LiteralNode, right: LiteralNode) -> bool:
Expand Down Expand Up @@ -215,6 +227,18 @@ def port_field_aggregator(
left.operand, right.operand
)

def floor(self, left: FloorNode, right: FloorNode) -> bool:
return self.visit(left.operand, right.operand)

def ceil(self, left: CeilNode, right: CeilNode) -> bool:
return self.visit(left.operand, right.operand)

def maximum(self, left: MaxNode, right: MaxNode) -> bool:
return self._visit_operands(left, right)

def minimum(self, left: MinNode, right: MinNode) -> bool:
return self._visit_operands(left, right)


def expressions_equal(
left: ExpressionNode, right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0
Expand Down
17 changes: 17 additions & 0 deletions src/gems/expression/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
#
# This file is part of the Antares project.

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict

from gems.expression.expression import (
AllTimeSumNode,
CeilNode,
ComponentParameterNode,
ComponentVariableNode,
FloorNode,
MaxNode,
MinNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
Expand Down Expand Up @@ -139,6 +144,18 @@ def port_field(self, node: PortFieldNode) -> float:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float:
raise NotImplementedError()

def floor(self, node: FloorNode) -> float:
return float(math.floor(visit(node.operand, self)))

def ceil(self, node: CeilNode) -> float:
return float(math.ceil(visit(node.operand, self)))

def maximum(self, node: MaxNode) -> float:
return max(visit(node.left, self), visit(node.right, self))

def minimum(self, node: MinNode) -> float:
return min(visit(node.left, self), visit(node.right, self))


def evaluate(expression: ExpressionNode, value_provider: ValueProvider) -> float:
return visit(expression, EvaluationVisitor(value_provider))
Expand Down
35 changes: 35 additions & 0 deletions src/gems/expression/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""
Defines the model for generic expressions.
"""

import enum
import inspect
from dataclasses import dataclass
Expand Down Expand Up @@ -117,6 +118,12 @@ def shift(self, shift: AnyExpression) -> "ExpressionNode":
def eval(self, time: AnyExpression) -> "ExpressionNode":
return TimeEvalNode(self, _wrap_in_node(time))

def floor(self) -> "ExpressionNode":
return FloorNode(self)

def ceil(self) -> "ExpressionNode":
return CeilNode(self)

def expec(self) -> "ExpressionNode":
return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation"))

Expand Down Expand Up @@ -368,6 +375,16 @@ class NegationNode(UnaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class FloorNode(UnaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class CeilNode(UnaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class BinaryOperatorNode(ExpressionNode):
left: ExpressionNode
Expand Down Expand Up @@ -400,6 +417,24 @@ class DivisionNode(BinaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class MaxNode(BinaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class MinNode(BinaryOperatorNode):
pass


def maximum(left: "ExpressionNode", right: "ExpressionNode") -> "MaxNode":
return MaxNode(left, right)


def minimum(left: "ExpressionNode", right: "ExpressionNode") -> "MinNode":
return MinNode(left, right)


@dataclass(frozen=True, eq=False)
class TimeShiftNode(UnaryOperatorNode):
time_shift: ExpressionNode
Expand Down
16 changes: 16 additions & 0 deletions src/gems/expression/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from .expression import (
AdditionNode,
AllTimeSumNode,
CeilNode,
ComparisonNode,
ComponentParameterNode,
ComponentVariableNode,
DivisionNode,
ExpressionNode,
FloorNode,
LiteralNode,
MaxNode,
MinNode,
MultiplicationNode,
NegationNode,
ParameterNode,
Expand Down Expand Up @@ -159,6 +163,18 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStruct
"Port fields aggregators must be resolved before computing indexing structure."
)

def floor(self, node: FloorNode) -> IndexingStructure:
return visit(node.operand, self)

def ceil(self, node: CeilNode) -> IndexingStructure:
return visit(node.operand, self)

def maximum(self, node: MaxNode) -> IndexingStructure:
return self._combine([node.left, node.right])

def minimum(self, node: MinNode) -> IndexingStructure:
return self._combine([node.left, node.right])


def compute_indexation(
expression: ExpressionNode, provider: IndexingStructureProvider
Expand Down
Loading
Loading