Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions pyomo/common/pyomo_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ def _get_fullqual_name(func: typing.Callable) -> str:
return f"{func.__module__}.{func.__qualname__}"


def overload(func: typing.Callable):
"""Wrap typing.overload that remembers the overloaded signatures
if typing.TYPE_CHECKING:
from typing import overload as overload
else:

This provides a custom implementation of typing.overload that
remembers the overloaded signatures so that they are available for
runtime inspection.
def overload(func: typing.Callable):
"""Wrap typing.overload that remembers the overloaded signatures

"""
_overloads.setdefault(_get_fullqual_name(func), []).append(func)
return typing.overload(func)
This provides a custom implementation of typing.overload that
remembers the overloaded signatures so that they are available for
runtime inspection.

"""
_overloads.setdefault(_get_fullqual_name(func), []).append(func)
return typing.overload(func)


def get_overloads_for(func: typing.Callable):
Expand Down
5 changes: 5 additions & 0 deletions pyomo/contrib/appsi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import weakref

from typing import (
TYPE_CHECKING,
Sequence,
Dict,
Optional,
Expand Down Expand Up @@ -1714,5 +1715,9 @@ class LegacySolver(LegacySolverInterface, cls):

return decorator

if TYPE_CHECKING:
# NOTE: `Factory.__call__` can return None, but for the common case
def __call__(self, name, **kwds) -> Solver: ...


SolverFactory = SolverFactoryClass()
10 changes: 9 additions & 1 deletion pyomo/contrib/solver/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
# ___________________________________________________________________________


from pyomo.opt.base.solvers import LegacySolverFactory
from typing import TYPE_CHECKING

from pyomo.common.factory import Factory
from pyomo.contrib.solver.common.base import LegacySolverWrapper
from pyomo.opt.base.solvers import LegacySolverFactory


class SolverFactoryClass(Factory):
Expand Down Expand Up @@ -107,6 +109,12 @@ class LegacySolver(LegacySolverWrapper, cls):

return decorator

if TYPE_CHECKING:
from pyomo.contrib.solver.common.base import SolverBase

# NOTE: `Factory.__call__` can return None, but for the common case
def __call__(self, name, **kwds) -> SolverBase: ...


#: Global registry/factory for "v2" solver interfaces.
SolverFactory: SolverFactoryClass = SolverFactoryClass()
9 changes: 7 additions & 2 deletions pyomo/core/base/PyomoModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from weakref import ref as weakref_ref
import gc
import math
from typing import TypeVar

from pyomo.common import timing
from pyomo.common.collections import Bunch
Expand Down Expand Up @@ -572,6 +573,10 @@ def select(
StaleFlagManager.mark_all_as_stale(delayed=True)


# NOTE: Python 3.11+ use `typing.Self`
ModelT = TypeVar("ModelT", bound="Model")


@ModelComponentFactory.register(
'Model objects can be used as a component of other models.'
)
Expand All @@ -583,9 +588,9 @@ class Model(ScalarBlock):

_Block_reserved_words = set()

def __new__(cls, *args, **kwds):
def __new__(cls: type[ModelT], *args, **kwds) -> ModelT:
if cls != Model:
return super(Model, cls).__new__(cls)
return super(Model, cls).__new__(cls) # type: ignore

raise TypeError(
"Directly creating the 'Model' class is not allowed. Please use the "
Expand Down
10 changes: 5 additions & 5 deletions pyomo/core/base/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,17 +2085,17 @@ class Block(ActiveIndexedComponent):
_ComponentDataClass = BlockData
_private_data_initializers = defaultdict(lambda: dict)

@overload
def __new__(
cls: Type[Block], *args, **kwds
) -> Union[ScalarBlock, IndexedBlock]: ...

@overload
def __new__(cls: Type[ScalarBlock], *args, **kwds) -> ScalarBlock: ...

@overload
def __new__(cls: Type[IndexedBlock], *args, **kwds) -> IndexedBlock: ...

@overload
def __new__(
cls: Type[Block], *args, **kwds
) -> Union[ScalarBlock, IndexedBlock]: ...

def __new__(cls, *args, **kwds):
if cls != Block:
return super(Block, cls).__new__(cls)
Expand Down
10 changes: 5 additions & 5 deletions pyomo/core/base/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,17 +638,17 @@ class Constraint(ActiveIndexedComponent):
Violated = Infeasible
Satisfied = Feasible

@overload
def __new__(
cls: Type[Constraint], *args, **kwds
) -> Union[ScalarConstraint, IndexedConstraint]: ...

@overload
def __new__(cls: Type[ScalarConstraint], *args, **kwds) -> ScalarConstraint: ...

@overload
def __new__(cls: Type[IndexedConstraint], *args, **kwds) -> IndexedConstraint: ...

@overload
def __new__(
cls: Type[Constraint], *args, **kwds
) -> Union[ScalarConstraint, IndexedConstraint]: ...

def __new__(cls, *args, **kwds):
if cls != Constraint:
return super().__new__(cls)
Expand Down
10 changes: 5 additions & 5 deletions pyomo/core/base/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,17 @@ class NoValue:

pass

@overload
def __new__(
cls: Type[Param], *args, **kwds
) -> Union[ScalarParam, IndexedParam]: ...

@overload
def __new__(cls: Type[ScalarParam], *args, **kwds) -> ScalarParam: ...

@overload
def __new__(cls: Type[IndexedParam], *args, **kwds) -> IndexedParam: ...

@overload
def __new__(
cls: Type[Param], *args, **kwds
) -> Union[ScalarParam, IndexedParam]: ...

def __new__(cls, *args, **kwds):
if cls != Param:
return super(Param, cls).__new__(cls)
Expand Down
4 changes: 2 additions & 2 deletions pyomo/core/base/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,10 +2128,10 @@ class SortedOrder:
_UnorderedInitializers = {set}

@overload
def __new__(cls: Type[Set], *args, **kwds) -> Union[SetData, IndexedSet]: ...
def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ...

@overload
def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ...
def __new__(cls: Type[Set], *args, **kwds) -> Union[SetData, IndexedSet]: ...

def __new__(cls, *args, **kwds):
if cls is not Set:
Expand Down
6 changes: 3 additions & 3 deletions pyomo/core/base/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,15 +575,15 @@ class Var(IndexedComponent, IndexedComponent_NDArrayMixin):

_ComponentDataClass = VarData

@overload
def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ...

@overload
def __new__(cls: Type[ScalarVar], *args, **kwargs) -> ScalarVar: ...

@overload
def __new__(cls: Type[IndexedVar], *args, **kwargs) -> IndexedVar: ...

@overload
def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ...

def __new__(cls, *args, **kwargs):
if cls is not Var:
return super(Var, cls).__new__(cls)
Expand Down
35 changes: 29 additions & 6 deletions pyomo/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from typing import TYPE_CHECKING, Any, Literal, overload

import pyomo.environ as _environ

if TYPE_CHECKING:
import pyomo.contrib.appsi.base as _appsi
import pyomo.contrib.solver.common.factory as _contrib
import pyomo.opt.base.solvers as _solvers

__doc__ = """
Preview capabilities through ``pyomo.__future__``
=================================================
Expand All @@ -28,13 +35,29 @@

"""

solver_factory_v1: "_solvers.SolverFactoryClass"
solver_factory_v2: "_appsi.SolverFactoryClass"
solver_factory_v3: "_contrib.SolverFactoryClass"


def __getattr__(name):
if name in ('solver_factory_v1', 'solver_factory_v2', 'solver_factory_v3'):
if name in ("solver_factory_v1", "solver_factory_v2", "solver_factory_v3"):
return solver_factory(int(name[-1]))
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


@overload
def solver_factory(version: None = None) -> int: ...
@overload
def solver_factory(version: Literal[1]) -> "_solvers.SolverFactoryClass": ...
@overload
def solver_factory(version: Literal[2]) -> "_appsi.SolverFactoryClass": ...
@overload
def solver_factory(version: Literal[3]) -> "_contrib.SolverFactoryClass": ...
@overload
def solver_factory(version: int) -> Any: ...


def solver_factory(version=None):
"""Get (or set) the active implementation of the SolverFactory

Expand Down Expand Up @@ -90,19 +113,19 @@ def solver_factory(version=None):
if current is None:
for ver, cls in versions.items():
if cls._cls is _environ.SolverFactory._cls:
solver_factory._active_version = ver
solver_factory._active_version = ver # type: ignore
break
return solver_factory._active_version
return solver_factory._active_version # type: ignore
#
# The user is just asking what the current SolverFactory is; tell them.
if version is None:
return solver_factory._active_version
return solver_factory._active_version # type: ignore
#
# Update the current SolverFactory to be a shim around (shallow copy
# of) the new active factory
src = versions.get(version, None)
if version is not None:
solver_factory._active_version = version
solver_factory._active_version = version # type: ignore
for attr in ('_description', '_cls', '_doc'):
setattr(_environ.SolverFactory, attr, getattr(src, attr))
else:
Expand All @@ -113,4 +136,4 @@ def solver_factory(version=None):
return src


solver_factory._active_version = solver_factory()
solver_factory._active_version = solver_factory() # type: ignore
6 changes: 6 additions & 0 deletions pyomo/opt/base/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import time
import logging
import shlex
from typing import overload

from pyomo.common import Factory
from pyomo.common.enums import SolverAPIVersion
Expand Down Expand Up @@ -144,6 +145,11 @@ def _solver_error(self, method_name):


class SolverFactoryClass(Factory):
@overload
def __call__(self, _name: None = None, **kwds) -> "SolverFactoryClass": ...
@overload
def __call__(self, _name, **kwds) -> "OptSolver": ...

def __call__(self, _name=None, **kwds):
if _name is None:
return self
Expand Down
Loading