Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ bin
outputs
build
src/gems.egg-info/
src/gemspy.egg-info/
4 changes: 3 additions & 1 deletion grammar/Expr.g4
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ expr
| 'sum' '(' expr ')' # allTimeSum
| 'sum_connections' '(' portFieldExpr ')' # portFieldSum
| 'sum' '(' from=shift '..' to=shift ',' expr ')' # timeSum
| IDENTIFIER '(' expr ')' # function
| IDENTIFIER '(' argList? ')' # function
| IDENTIFIER '[' shift ']' # timeShift
| IDENTIFIER '[' expr ']' # timeIndex
| '(' expr ')' '[' shift ']' # timeShiftExpr
| '(' expr ')' '[' expr ']' # timeIndexExpr
;

argList : expr (',' expr)* ;

atom
: NUMBER # number
| IDENTIFIER # identifier
Expand Down
6 changes: 6 additions & 0 deletions src/gems/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@
)
from .expression import (
AdditionNode,
CeilNode,
Comparator,
ComparisonNode,
DivisionNode,
ExpressionNode,
FloorNode,
LiteralNode,
MaxNode,
MinNode,
MultiplicationNode,
NegationNode,
ParameterNode,
VariableNode,
literal,
maximum,
minimum,
param,
sum_expressions,
var,
Expand Down
16 changes: 16 additions & 0 deletions src/gems/expression/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@

from .expression import (
AllTimeSumNode,
CeilNode,
ComparisonNode,
ComponentParameterNode,
ComponentVariableNode,
ExpressionNode,
FloorNode,
LiteralNode,
MaxNode,
MinNode,
ParameterNode,
PortFieldAggregatorNode,
PortFieldNode,
Expand Down Expand Up @@ -95,6 +99,18 @@ def port_field(self, node: PortFieldNode) -> ExpressionNode:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode:
return PortFieldAggregatorNode(visit(node.operand, self), node.aggregator)

def floor(self, node: FloorNode) -> ExpressionNode:
return FloorNode(visit(node.operand, self))

def ceil(self, node: CeilNode) -> ExpressionNode:
return CeilNode(visit(node.operand, self))

def maximum(self, node: MaxNode) -> ExpressionNode:
return MaxNode([visit(op, self) for op in node.operands])

def minimum(self, node: MinNode) -> ExpressionNode:
return MinNode([visit(op, self) for op in node.operands])


def copy_expression(expression: ExpressionNode) -> ExpressionNode:
return visit(expression, CopyVisitor())
62 changes: 41 additions & 21 deletions src/gems/expression/degree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@
#
# This file is part of the Antares project.

import math

import gems.expression.scenario_operator
from gems.expression.expression import (
AllTimeSumNode,
CeilNode,
ComponentParameterNode,
ComponentVariableNode,
FloorNode,
MaxNode,
MinNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
Expand All @@ -39,77 +45,91 @@
from .visitor import ExpressionVisitor, T, visit


class ExpressionDegreeVisitor(ExpressionVisitor[int]):
class ExpressionDegreeVisitor(ExpressionVisitor[int | float]):
"""
Computes degree of expression with respect to variables.
"""

def literal(self, node: LiteralNode) -> int:
def literal(self, node: LiteralNode) -> int | float:
return 0

def negation(self, node: NegationNode) -> int:
def negation(self, node: NegationNode) -> int | float:
return visit(node.operand, self)

# TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div
def addition(self, node: AdditionNode) -> int:
def addition(self, node: AdditionNode) -> int | float:
degrees = [visit(o, self) for o in node.operands]
return max(degrees)

def multiplication(self, node: MultiplicationNode) -> int:
def multiplication(self, node: MultiplicationNode) -> int | float:
return visit(node.left, self) + visit(node.right, self)

def division(self, node: DivisionNode) -> int:
def division(self, node: DivisionNode) -> int | float:
right_degree = visit(node.right, self)
if right_degree != 0:
raise ValueError("Degree computation not implemented for divisions.")
return visit(node.left, self)

def comparison(self, node: ComparisonNode) -> int:
def comparison(self, node: ComparisonNode) -> int | float:
return max(visit(node.left, self), visit(node.right, self))

def variable(self, node: VariableNode) -> int:
def variable(self, node: VariableNode) -> int | float:
return 1

def parameter(self, node: ParameterNode) -> int:
def parameter(self, node: ParameterNode) -> int | float:
return 0

def comp_variable(self, node: ComponentVariableNode) -> int:
def comp_variable(self, node: ComponentVariableNode) -> int | float:
return 1

def comp_parameter(self, node: ComponentParameterNode) -> int:
def comp_parameter(self, node: ComponentParameterNode) -> int | float:
return 0

def pb_variable(self, node: ProblemVariableNode) -> int:
def pb_variable(self, node: ProblemVariableNode) -> int | float:
return 1

def pb_parameter(self, node: ProblemParameterNode) -> int:
def pb_parameter(self, node: ProblemParameterNode) -> int | float:
return 0

def time_shift(self, node: TimeShiftNode) -> int:
def time_shift(self, node: TimeShiftNode) -> int | float:
return visit(node.operand, self)

def time_eval(self, node: TimeEvalNode) -> int:
def time_eval(self, node: TimeEvalNode) -> int | float:
return visit(node.operand, self)

def time_sum(self, node: TimeSumNode) -> int:
def time_sum(self, node: TimeSumNode) -> int | float:
return visit(node.operand, self)

def all_time_sum(self, node: AllTimeSumNode) -> int:
def all_time_sum(self, node: AllTimeSumNode) -> int | float:
return visit(node.operand, self)

def scenario_operator(self, node: ScenarioOperatorNode) -> int:
def scenario_operator(self, node: ScenarioOperatorNode) -> int | float:
scenario_operator_cls = getattr(gems.expression.scenario_operator, node.name)
# TODO: Carefully check if this formula is correct
return scenario_operator_cls.degree() * visit(node.operand, self)

def port_field(self, node: PortFieldNode) -> int:
def port_field(self, node: PortFieldNode) -> int | float:
return 1

def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int | float:
return visit(node.operand, self)

def floor(self, node: FloorNode) -> int | float:
d = visit(node.operand, self)
return 0 if d == 0 else math.inf

def ceil(self, node: CeilNode) -> int | float:
d = visit(node.operand, self)
return 0 if d == 0 else math.inf

def maximum(self, node: MaxNode) -> int | float:
return 0 if all(visit(op, self) == 0 for op in node.operands) else math.inf

def minimum(self, node: MinNode) -> int | float:
return 0 if all(visit(op, self) == 0 for op in node.operands) else math.inf


def compute_degree(expression: ExpressionNode) -> int:
def compute_degree(expression: ExpressionNode) -> int | float:
return visit(expression, ExpressionDegreeVisitor())


Expand Down
28 changes: 28 additions & 0 deletions src/gems/expression/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@
from gems.expression.expression import (
AllTimeSumNode,
BinaryOperatorNode,
CeilNode,
ComponentParameterNode,
ComponentVariableNode,
FloorNode,
MaxNode,
MinNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
Expand Down Expand Up @@ -111,6 +115,14 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool:
right, PortFieldAggregatorNode
):
return self.port_field_aggregator(left, right)
if isinstance(left, FloorNode) and isinstance(right, FloorNode):
return self.floor(left, right)
if isinstance(left, CeilNode) and isinstance(right, CeilNode):
return self.ceil(left, right)
if isinstance(left, MaxNode) and isinstance(right, MaxNode):
return self.maximum(left, right)
if isinstance(left, MinNode) and isinstance(right, MinNode):
return self.minimum(left, right)
raise NotImplementedError(f"Equality not implemented for {left.__class__}")

def literal(self, left: LiteralNode, right: LiteralNode) -> bool:
Expand Down Expand Up @@ -215,6 +227,22 @@ def port_field_aggregator(
left.operand, right.operand
)

def floor(self, left: FloorNode, right: FloorNode) -> bool:
return self.visit(left.operand, right.operand)

def ceil(self, left: CeilNode, right: CeilNode) -> bool:
return self.visit(left.operand, right.operand)

def maximum(self, left: MaxNode, right: MaxNode) -> bool:
return len(left.operands) == len(right.operands) and all(
self.visit(l, r) for l, r in zip(left.operands, right.operands)
)

def minimum(self, left: MinNode, right: MinNode) -> bool:
return len(left.operands) == len(right.operands) and all(
self.visit(l, r) for l, r in zip(left.operands, right.operands)
)


def expressions_equal(
left: ExpressionNode, right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0
Expand Down
17 changes: 17 additions & 0 deletions src/gems/expression/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
#
# This file is part of the Antares project.

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict

from gems.expression.expression import (
AllTimeSumNode,
CeilNode,
ComponentParameterNode,
ComponentVariableNode,
FloorNode,
MaxNode,
MinNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
Expand Down Expand Up @@ -139,6 +144,18 @@ def port_field(self, node: PortFieldNode) -> float:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float:
raise NotImplementedError()

def floor(self, node: FloorNode) -> float:
return float(math.floor(visit(node.operand, self)))

def ceil(self, node: CeilNode) -> float:
return float(math.ceil(visit(node.operand, self)))

def maximum(self, node: MaxNode) -> float:
return max(visit(op, self) for op in node.operands)

def minimum(self, node: MinNode) -> float:
return min(visit(op, self) for op in node.operands)


def evaluate(expression: ExpressionNode, value_provider: ValueProvider) -> float:
return visit(expression, EvaluationVisitor(value_provider))
Expand Down
35 changes: 35 additions & 0 deletions src/gems/expression/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""
Defines the model for generic expressions.
"""

import enum
import inspect
from dataclasses import dataclass
Expand Down Expand Up @@ -117,6 +118,12 @@ def shift(self, shift: AnyExpression) -> "ExpressionNode":
def eval(self, time: AnyExpression) -> "ExpressionNode":
return TimeEvalNode(self, _wrap_in_node(time))

def floor(self) -> "ExpressionNode":
return FloorNode(self)

def ceil(self) -> "ExpressionNode":
return CeilNode(self)

def expec(self) -> "ExpressionNode":
return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation"))

Expand Down Expand Up @@ -368,6 +375,16 @@ class NegationNode(UnaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class FloorNode(UnaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class CeilNode(UnaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class BinaryOperatorNode(ExpressionNode):
left: ExpressionNode
Expand Down Expand Up @@ -400,6 +417,24 @@ class DivisionNode(BinaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class MaxNode(ExpressionNode):
operands: List[ExpressionNode]


@dataclass(frozen=True, eq=False)
class MinNode(ExpressionNode):
operands: List[ExpressionNode]


def maximum(*operands: "ExpressionNode") -> "MaxNode":
return MaxNode(list(operands))


def minimum(*operands: "ExpressionNode") -> "MinNode":
return MinNode(list(operands))


@dataclass(frozen=True, eq=False)
class TimeShiftNode(UnaryOperatorNode):
time_shift: ExpressionNode
Expand Down
16 changes: 16 additions & 0 deletions src/gems/expression/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from .expression import (
AdditionNode,
AllTimeSumNode,
CeilNode,
ComparisonNode,
ComponentParameterNode,
ComponentVariableNode,
DivisionNode,
ExpressionNode,
FloorNode,
LiteralNode,
MaxNode,
MinNode,
MultiplicationNode,
NegationNode,
ParameterNode,
Expand Down Expand Up @@ -159,6 +163,18 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStruct
"Port fields aggregators must be resolved before computing indexing structure."
)

def floor(self, node: FloorNode) -> IndexingStructure:
return visit(node.operand, self)

def ceil(self, node: CeilNode) -> IndexingStructure:
return visit(node.operand, self)

def maximum(self, node: MaxNode) -> IndexingStructure:
return self._combine(node.operands)

def minimum(self, node: MinNode) -> IndexingStructure:
return self._combine(node.operands)


def compute_indexation(
expression: ExpressionNode, provider: IndexingStructureProvider
Expand Down
Loading
Loading