diff --git a/pyomo/common/pyomo_typing.py b/pyomo/common/pyomo_typing.py index 35f3567432a..be4631e97cd 100644 --- a/pyomo/common/pyomo_typing.py +++ b/pyomo/common/pyomo_typing.py @@ -18,16 +18,20 @@ def _get_fullqual_name(func: typing.Callable) -> str: return f"{func.__module__}.{func.__qualname__}" -def overload(func: typing.Callable): - """Wrap typing.overload that remembers the overloaded signatures +if typing.TYPE_CHECKING: + from typing import overload as overload +else: - This provides a custom implementation of typing.overload that - remembers the overloaded signatures so that they are available for - runtime inspection. + def overload(func: typing.Callable): + """Wrap typing.overload that remembers the overloaded signatures - """ - _overloads.setdefault(_get_fullqual_name(func), []).append(func) - return typing.overload(func) + This provides a custom implementation of typing.overload that + remembers the overloaded signatures so that they are available for + runtime inspection. + + """ + _overloads.setdefault(_get_fullqual_name(func), []).append(func) + return typing.overload(func) def get_overloads_for(func: typing.Callable): diff --git a/pyomo/contrib/appsi/base.py b/pyomo/contrib/appsi/base.py index 181cb5e28ab..44536c57c03 100644 --- a/pyomo/contrib/appsi/base.py +++ b/pyomo/contrib/appsi/base.py @@ -16,6 +16,7 @@ import weakref from typing import ( + TYPE_CHECKING, Sequence, Dict, Optional, @@ -1714,5 +1715,9 @@ class LegacySolver(LegacySolverInterface, cls): return decorator + if TYPE_CHECKING: + # NOTE: `Factory.__call__` can return None, but for the common case + def __call__(self, name, **kwds) -> Solver: ... + SolverFactory = SolverFactoryClass() diff --git a/pyomo/contrib/solver/common/factory.py b/pyomo/contrib/solver/common/factory.py index 8af5de6ab9c..a982b05477a 100644 --- a/pyomo/contrib/solver/common/factory.py +++ b/pyomo/contrib/solver/common/factory.py @@ -10,9 +10,11 @@ # ___________________________________________________________________________ -from pyomo.opt.base.solvers import LegacySolverFactory +from typing import TYPE_CHECKING + from pyomo.common.factory import Factory from pyomo.contrib.solver.common.base import LegacySolverWrapper +from pyomo.opt.base.solvers import LegacySolverFactory class SolverFactoryClass(Factory): @@ -107,6 +109,12 @@ class LegacySolver(LegacySolverWrapper, cls): return decorator + if TYPE_CHECKING: + from pyomo.contrib.solver.common.base import SolverBase + + # NOTE: `Factory.__call__` can return None, but for the common case + def __call__(self, name, **kwds) -> SolverBase: ... + #: Global registry/factory for "v2" solver interfaces. SolverFactory: SolverFactoryClass = SolverFactoryClass() diff --git a/pyomo/core/base/PyomoModel.py b/pyomo/core/base/PyomoModel.py index c2da52eed5f..accdba29148 100644 --- a/pyomo/core/base/PyomoModel.py +++ b/pyomo/core/base/PyomoModel.py @@ -14,6 +14,7 @@ from weakref import ref as weakref_ref import gc import math +from typing import TypeVar from pyomo.common import timing from pyomo.common.collections import Bunch @@ -572,6 +573,10 @@ def select( StaleFlagManager.mark_all_as_stale(delayed=True) +# NOTE: Python 3.11+ use `typing.Self` +ModelT = TypeVar("ModelT", bound="Model") + + @ModelComponentFactory.register( 'Model objects can be used as a component of other models.' ) @@ -583,9 +588,9 @@ class Model(ScalarBlock): _Block_reserved_words = set() - def __new__(cls, *args, **kwds): + def __new__(cls: type[ModelT], *args, **kwds) -> ModelT: if cls != Model: - return super(Model, cls).__new__(cls) + return super(Model, cls).__new__(cls) # type: ignore raise TypeError( "Directly creating the 'Model' class is not allowed. Please use the " diff --git a/pyomo/core/base/block.py b/pyomo/core/base/block.py index 97b2dee721b..851eda9d74c 100644 --- a/pyomo/core/base/block.py +++ b/pyomo/core/base/block.py @@ -2085,17 +2085,17 @@ class Block(ActiveIndexedComponent): _ComponentDataClass = BlockData _private_data_initializers = defaultdict(lambda: dict) - @overload - def __new__( - cls: Type[Block], *args, **kwds - ) -> Union[ScalarBlock, IndexedBlock]: ... - @overload def __new__(cls: Type[ScalarBlock], *args, **kwds) -> ScalarBlock: ... @overload def __new__(cls: Type[IndexedBlock], *args, **kwds) -> IndexedBlock: ... + @overload + def __new__( + cls: Type[Block], *args, **kwds + ) -> Union[ScalarBlock, IndexedBlock]: ... + def __new__(cls, *args, **kwds): if cls != Block: return super(Block, cls).__new__(cls) diff --git a/pyomo/core/base/constraint.py b/pyomo/core/base/constraint.py index 970c393425b..8f15060c93e 100644 --- a/pyomo/core/base/constraint.py +++ b/pyomo/core/base/constraint.py @@ -638,17 +638,17 @@ class Constraint(ActiveIndexedComponent): Violated = Infeasible Satisfied = Feasible - @overload - def __new__( - cls: Type[Constraint], *args, **kwds - ) -> Union[ScalarConstraint, IndexedConstraint]: ... - @overload def __new__(cls: Type[ScalarConstraint], *args, **kwds) -> ScalarConstraint: ... @overload def __new__(cls: Type[IndexedConstraint], *args, **kwds) -> IndexedConstraint: ... + @overload + def __new__( + cls: Type[Constraint], *args, **kwds + ) -> Union[ScalarConstraint, IndexedConstraint]: ... + def __new__(cls, *args, **kwds): if cls != Constraint: return super().__new__(cls) diff --git a/pyomo/core/base/param.py b/pyomo/core/base/param.py index 02ba103cae3..d6406f2ba98 100644 --- a/pyomo/core/base/param.py +++ b/pyomo/core/base/param.py @@ -309,17 +309,17 @@ class NoValue: pass - @overload - def __new__( - cls: Type[Param], *args, **kwds - ) -> Union[ScalarParam, IndexedParam]: ... - @overload def __new__(cls: Type[ScalarParam], *args, **kwds) -> ScalarParam: ... @overload def __new__(cls: Type[IndexedParam], *args, **kwds) -> IndexedParam: ... + @overload + def __new__( + cls: Type[Param], *args, **kwds + ) -> Union[ScalarParam, IndexedParam]: ... + def __new__(cls, *args, **kwds): if cls != Param: return super(Param, cls).__new__(cls) diff --git a/pyomo/core/base/set.py b/pyomo/core/base/set.py index 8c858f96ba8..4f62c822310 100644 --- a/pyomo/core/base/set.py +++ b/pyomo/core/base/set.py @@ -2128,10 +2128,10 @@ class SortedOrder: _UnorderedInitializers = {set} @overload - def __new__(cls: Type[Set], *args, **kwds) -> Union[SetData, IndexedSet]: ... + def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ... @overload - def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ... + def __new__(cls: Type[Set], *args, **kwds) -> Union[SetData, IndexedSet]: ... def __new__(cls, *args, **kwds): if cls is not Set: diff --git a/pyomo/core/base/var.py b/pyomo/core/base/var.py index 6b9b5fb4151..9bb9555ae10 100644 --- a/pyomo/core/base/var.py +++ b/pyomo/core/base/var.py @@ -575,15 +575,15 @@ class Var(IndexedComponent, IndexedComponent_NDArrayMixin): _ComponentDataClass = VarData - @overload - def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ... - @overload def __new__(cls: Type[ScalarVar], *args, **kwargs) -> ScalarVar: ... @overload def __new__(cls: Type[IndexedVar], *args, **kwargs) -> IndexedVar: ... + @overload + def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ... + def __new__(cls, *args, **kwargs): if cls is not Var: return super(Var, cls).__new__(cls) diff --git a/pyomo/future.py b/pyomo/future.py index 2d8718f5f47..94168324d3c 100644 --- a/pyomo/future.py +++ b/pyomo/future.py @@ -9,8 +9,15 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +from typing import TYPE_CHECKING, Any, Literal, overload + import pyomo.environ as _environ +if TYPE_CHECKING: + import pyomo.contrib.appsi.base as _appsi + import pyomo.contrib.solver.common.factory as _contrib + import pyomo.opt.base.solvers as _solvers + __doc__ = """ Preview capabilities through ``pyomo.__future__`` ================================================= @@ -28,13 +35,29 @@ """ +solver_factory_v1: "_solvers.SolverFactoryClass" +solver_factory_v2: "_appsi.SolverFactoryClass" +solver_factory_v3: "_contrib.SolverFactoryClass" + def __getattr__(name): - if name in ('solver_factory_v1', 'solver_factory_v2', 'solver_factory_v3'): + if name in ("solver_factory_v1", "solver_factory_v2", "solver_factory_v3"): return solver_factory(int(name[-1])) raise AttributeError(f"module '{__name__}' has no attribute '{name}'") +@overload +def solver_factory(version: None = None) -> int: ... +@overload +def solver_factory(version: Literal[1]) -> "_solvers.SolverFactoryClass": ... +@overload +def solver_factory(version: Literal[2]) -> "_appsi.SolverFactoryClass": ... +@overload +def solver_factory(version: Literal[3]) -> "_contrib.SolverFactoryClass": ... +@overload +def solver_factory(version: int) -> Any: ... + + def solver_factory(version=None): """Get (or set) the active implementation of the SolverFactory @@ -90,19 +113,19 @@ def solver_factory(version=None): if current is None: for ver, cls in versions.items(): if cls._cls is _environ.SolverFactory._cls: - solver_factory._active_version = ver + solver_factory._active_version = ver # type: ignore break - return solver_factory._active_version + return solver_factory._active_version # type: ignore # # The user is just asking what the current SolverFactory is; tell them. if version is None: - return solver_factory._active_version + return solver_factory._active_version # type: ignore # # Update the current SolverFactory to be a shim around (shallow copy # of) the new active factory src = versions.get(version, None) if version is not None: - solver_factory._active_version = version + solver_factory._active_version = version # type: ignore for attr in ('_description', '_cls', '_doc'): setattr(_environ.SolverFactory, attr, getattr(src, attr)) else: @@ -113,4 +136,4 @@ def solver_factory(version=None): return src -solver_factory._active_version = solver_factory() +solver_factory._active_version = solver_factory() # type: ignore diff --git a/pyomo/opt/base/solvers.py b/pyomo/opt/base/solvers.py index 158d6888f14..563376704e2 100644 --- a/pyomo/opt/base/solvers.py +++ b/pyomo/opt/base/solvers.py @@ -14,6 +14,7 @@ import time import logging import shlex +from typing import overload from pyomo.common import Factory from pyomo.common.enums import SolverAPIVersion @@ -144,6 +145,11 @@ def _solver_error(self, method_name): class SolverFactoryClass(Factory): + @overload + def __call__(self, _name: None = None, **kwds) -> "SolverFactoryClass": ... + @overload + def __call__(self, _name, **kwds) -> "OptSolver": ... + def __call__(self, _name=None, **kwds): if _name is None: return self