diff --git a/src/gems/expression/evaluate.py b/src/gems/expression/evaluate.py index 2c850eff..eb67240d 100644 --- a/src/gems/expression/evaluate.py +++ b/src/gems/expression/evaluate.py @@ -13,7 +13,7 @@ import math from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Dict +from typing import Dict, Iterable from gems.expression.expression import ( AllTimeSumNode, @@ -51,20 +51,25 @@ class ValueProvider(ABC): """ @abstractmethod - def get_variable_value(self, name: str) -> float: - ... + def get_variable_value(self, name: str) -> float: ... @abstractmethod - def get_parameter_value(self, name: str) -> float: - ... + def get_parameter_value(self, name: str) -> float: ... @abstractmethod - def get_component_variable_value(self, component_id: str, name: str) -> float: - ... + def get_component_variable_value(self, component_id: str, name: str) -> float: ... @abstractmethod - def get_component_parameter_value(self, component_id: str, name: str) -> float: - ... + def get_component_parameter_value(self, component_id: str, name: str) -> float: ... + + @abstractmethod + def shift(self, offset: int) -> "ValueProvider": ... + + @abstractmethod + def eval_at(self, timestep: int) -> "ValueProvider": ... + + @abstractmethod + def all_block_timesteps(self) -> Iterable[int]: ... @dataclass(frozen=True) @@ -89,6 +94,15 @@ def get_component_variable_value(self, component_id: str, name: str) -> float: def get_component_parameter_value(self, component_id: str, name: str) -> float: raise NotImplementedError() + def shift(self, offset: int) -> "ValueProvider": + return self + + def eval_at(self, timestep: int) -> "ValueProvider": + return self + + def all_block_timesteps(self) -> Iterable[int]: + return range(1) + @dataclass(frozen=True) class EvaluationVisitor(ExpressionVisitorOperations[float]): @@ -124,16 +138,26 @@ def pb_variable(self, node: ProblemVariableNode) -> float: raise ValueError("Should not reach here.") def time_shift(self, node: TimeShiftNode) -> float: - raise NotImplementedError() + shift = int(visit(node.time_shift, self)) + return visit(node.operand, EvaluationVisitor(self.context.shift(shift))) def time_eval(self, node: TimeEvalNode) -> float: - raise NotImplementedError() + timestep = int(visit(node.eval_time, self)) + return visit(node.operand, EvaluationVisitor(self.context.eval_at(timestep))) def time_sum(self, node: TimeSumNode) -> float: - raise NotImplementedError() + from_shift = int(visit(node.from_time, self)) + to_shift = int(visit(node.to_time, self)) + return sum( + visit(node.operand, EvaluationVisitor(self.context.shift(t))) + for t in range(from_shift, to_shift + 1) + ) def all_time_sum(self, node: AllTimeSumNode) -> float: - raise NotImplementedError() + return sum( + visit(node.operand, EvaluationVisitor(self.context.eval_at(t))) + for t in self.context.all_block_timesteps() + ) def scenario_operator(self, node: ScenarioOperatorNode) -> float: raise NotImplementedError() diff --git a/src/gems/simulation/extra_output.py b/src/gems/simulation/extra_output.py index eede420f..49ff9019 100644 --- a/src/gems/simulation/extra_output.py +++ b/src/gems/simulation/extra_output.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional from gems.expression.evaluate import ValueProvider from gems.simulation.optimization import OptimizationProblem @@ -91,3 +91,12 @@ def get_component_variable_value(self, component_id: str, name: str) -> float: def get_component_parameter_value(self, component_id: str, name: str) -> float: return self.context[name] + + def shift(self, offset: int) -> "ValueProvider": + return self + + def eval_at(self, timestep: int) -> "ValueProvider": + return self + + def all_block_timesteps(self) -> Iterable[int]: + return range(1) diff --git a/src/gems/simulation/optimization.py b/src/gems/simulation/optimization.py index 1d43b1de..74526248 100644 --- a/src/gems/simulation/optimization.py +++ b/src/gems/simulation/optimization.py @@ -103,6 +103,18 @@ def get_parameter_value(self, name: str) -> float: "Parameter must be associated to its component before resolution." ) + def shift(self, offset: int) -> ValueProvider: + new_timestep = ( + (block_timestep + offset) if block_timestep is not None else None + ) + return _make_value_provider(context, new_timestep, scenario) + + def eval_at(self, timestep: int) -> ValueProvider: + return _make_value_provider(context, timestep, scenario) + + def all_block_timesteps(self) -> Iterable[int]: + return range(context.block_length()) + return Impl() @@ -415,6 +427,15 @@ def get_parameter_value(self, name: str) -> float: "Parameter must be associated to its component before resolution." ) + def shift(self, offset: int) -> ValueProvider: + return self + + def eval_at(self, timestep: int) -> ValueProvider: + return self + + def all_block_timesteps(self) -> Iterable[int]: + return range(1) + return Impl() @@ -889,11 +910,9 @@ def fusion_problems( root_var = root_vars[var.name()] root_var.SetLb(var.lb()) root_var.SetUb(var.ub()) - root_master.context._solver_variables[ - var.name() - ].is_in_objective = context._solver_variables[ - var.name() - ].is_in_objective + root_master.context._solver_variables[var.name()].is_in_objective = ( + context._solver_variables[var.name()].is_in_objective + ) for cstr in master.solver.constraints(): coeff = cstr.GetCoefficient(var) diff --git a/tests/unittests/expressions/visitor/test_evaluation.py b/tests/unittests/expressions/visitor/test_evaluation.py index efd110ff..ac9e13e5 100644 --- a/tests/unittests/expressions/visitor/test_evaluation.py +++ b/tests/unittests/expressions/visitor/test_evaluation.py @@ -11,7 +11,7 @@ # This file is part of the Antares project. from dataclasses import dataclass, field -from typing import Dict +from typing import Dict, Iterable import pytest @@ -68,6 +68,15 @@ def get_component_variable_value(self, component_id: str, name: str) -> float: def get_component_parameter_value(self, component_id: str, name: str) -> float: return self.parameters[comp_key(component_id, name)] + def shift(self, offset: int) -> "ValueProvider": + return self + + def eval_at(self, timestep: int) -> "ValueProvider": + return self + + def all_block_timesteps(self) -> Iterable[int]: + return range(1) + def test_comp_parameter() -> None: add_node = AdditionNode([LiteralNode(1), ComponentVariableNode("comp1", "x")])