Skip to content

Commit

Permalink
fix(types): projectors
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Sep 17, 2024
1 parent 3601736 commit bc48e5a
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 50 deletions.
4 changes: 2 additions & 2 deletions openfisca_core/projectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#
# See: https://www.python.org/dev/peps/pep-0008/#imports

from . import typing
from . import types
from .entity_to_person_projector import EntityToPersonProjector
from .first_person_to_entity_projector import FirstPersonToEntityProjector
from .helpers import get_projector_from_shortcut, projectable
Expand All @@ -35,5 +35,5 @@
"projectable",
"Projector",
"UniqueRoleToEntityProjector",
"typing",
"types",
]
15 changes: 7 additions & 8 deletions openfisca_core/projectors/helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

from collections.abc import Mapping

from openfisca_core.types import GroupEntity, Role, SingleEntity
from collections.abc import Iterable, Mapping

from openfisca_core import entities, projectors

from .typing import GroupPopulation, Population
from .types import GroupEntity, GroupPopulation, Role, SingleEntity, SinglePopulation


def projectable(function):
Expand All @@ -19,7 +17,7 @@ def projectable(function):


def get_projector_from_shortcut(
population: Population | GroupPopulation,
population: SinglePopulation | GroupPopulation,
shortcut: str,
parent: projectors.Projector | None = None,
) -> projectors.Projector | None:
Expand All @@ -46,7 +44,7 @@ def get_projector_from_shortcut(
of a specific Simulation and TaxBenefitSystem.
Args:
population (Population | GroupPopulation): Where to project from.
population (SinglePopulation | GroupPopulation): Where to project from.
shortcut (str): Where to project to.
parent: ???
Expand Down Expand Up @@ -114,7 +112,7 @@ def get_projector_from_shortcut(

if isinstance(entity, entities.Entity):
populations: Mapping[
str, Population | GroupPopulation
str, SinglePopulation | GroupPopulation
] = population.simulation.populations

if shortcut not in populations.keys():
Expand All @@ -126,7 +124,8 @@ def get_projector_from_shortcut(
return projectors.FirstPersonToEntityProjector(population, parent)

if isinstance(entity, entities.GroupEntity):
role: Role | None = entities.find_role(entity.roles, shortcut, total=1)
roles: Iterable[Role] = entity.roles
role: Role | None = entities.find_role(roles, shortcut, total=1)

if role is not None:
return projectors.UniqueRoleToEntityProjector(population, role, parent)
Expand Down
52 changes: 52 additions & 0 deletions openfisca_core/projectors/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Protocol

from openfisca_core import types as t

# Entities


class SingleEntity(t.SingleEntity, Protocol):
...


class GroupEntity(t.GroupEntity, Protocol):
...


class Role(t.Role, Protocol):
...


# Populations


class SinglePopulation(t.SinglePopulation, Protocol):
@property
def entity(self) -> t.SingleEntity:
...

@property
def simulation(self) -> Simulation:
...


class GroupPopulation(t.GroupPopulation, Protocol):
@property
def entity(self) -> t.GroupEntity:
...

@property
def simulation(self) -> Simulation:
...


# Simulations


class Simulation(t.Simulation, Protocol):
@property
def populations(self) -> Mapping[str, SinglePopulation | GroupPopulation]:
...
32 changes: 0 additions & 32 deletions openfisca_core/projectors/typing.py

This file was deleted.

14 changes: 8 additions & 6 deletions openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Dict, Mapping, NamedTuple, Optional, Set

from openfisca_core.types import Population, TaxBenefitSystem, Variable
from openfisca_core.types import SinglePopulation, TaxBenefitSystem, Variable

import tempfile
import warnings
Expand All @@ -19,13 +19,13 @@ class Simulation:
"""

tax_benefit_system: TaxBenefitSystem
populations: Dict[str, Population]
populations: Dict[str, SinglePopulation]
invalidated_caches: Set[Cache]

def __init__(
self,
tax_benefit_system: TaxBenefitSystem,
populations: Mapping[str, Population],
populations: Mapping[str, SinglePopulation],
):
"""
This constructor is reserved for internal use; see :any:`SimulationBuilder`,
Expand Down Expand Up @@ -531,7 +531,7 @@ def set_input(self, variable_name: str, period, value):
return
self.get_holder(variable_name).set_input(period, value)

def get_variable_population(self, variable_name: str) -> Population:
def get_variable_population(self, variable_name: str) -> SinglePopulation:
variable: Optional[Variable]

variable = self.tax_benefit_system.get_variable(
Expand All @@ -543,7 +543,9 @@ def get_variable_population(self, variable_name: str) -> Population:

return self.populations[variable.entity.key]

def get_population(self, plural: Optional[str] = None) -> Optional[Population]:
def get_population(
self, plural: Optional[str] = None
) -> Optional[SinglePopulation]:
return next(
(
population
Expand All @@ -556,7 +558,7 @@ def get_population(self, plural: Optional[str] = None) -> Optional[Population]:
def get_entity(
self,
plural: Optional[str] = None,
) -> Optional[Population]:
) -> Optional[SinglePopulation]:
population = self.get_population(plural)
return population and population.entity

Expand Down
12 changes: 10 additions & 2 deletions openfisca_core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,21 @@ def unit(self) -> DateUnit:
# Populations


class Population(Protocol):
class CorePopulation(Protocol):
...


class SinglePopulation(CorePopulation, Protocol):
entity: Any

def get_holder(self, variable_name: Any) -> Any:
...


class GroupPopulation(CorePopulation, Protocol):
...


# Simulations


Expand Down Expand Up @@ -163,7 +171,7 @@ class Variable(Protocol):
class Formula(Protocol):
def __call__(
self,
population: Population,
population: GroupPopulation,
instant: Instant,
params: Params,
) -> Array[Any]:
Expand Down
1 change: 1 addition & 0 deletions openfisca_tasks/lint.mk
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ check-types:
openfisca_core/commons \
openfisca_core/entities \
openfisca_core/periods \
openfisca_core/projectors \
openfisca_core/types.py
@$(call print_pass,$@:)

Expand Down

0 comments on commit bc48e5a

Please sign in to comment.