Skip to content

TEST: TypeVar defaults #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: my-main
Choose a base branch
from
10 changes: 7 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
@@ -6750,8 +6750,8 @@ def named_type(self, name: str) -> Instance:
assert isinstance(node.target, Instance) # type: ignore[misc]
node = node.target.type
assert isinstance(node, TypeInfo)
any_type = AnyType(TypeOfAny.from_omitted_generics)
return Instance(node, [any_type] * len(node.defn.type_vars))

return Instance(node, [type_var.default for type_var in node.defn.type_vars])

def named_generic_type(self, name: str, args: list[Type]) -> Instance:
"""Return an instance with the given name and type arguments.
@@ -6760,7 +6760,11 @@ def named_generic_type(self, name: str, args: list[Type]) -> Instance:
the name refers to a compatible generic type.
"""
info = self.lookup_typeinfo(name)
args = [remove_instance_last_known_values(arg) for arg in args]
args = [
remove_instance_last_known_values(arg)
for arg in args
# + [tv.default for tv in info.defn.type_vars[len(args) - len(info.defn.type_vars) :]]
]
# TODO: assert len(args) == len(info.defn.type_vars)
return Instance(info, args)

5 changes: 5 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
@@ -2542,6 +2542,11 @@ def check_argument_types(
callee_arg_kind,
allow_unpack=isinstance(callee_arg_type, UnpackType),
)
if (
isinstance(callee_arg_type, types.TypeVarType)
and callee_arg_type.has_default()
):
callee_arg_type = callee_arg_type.default
check_arg(
expanded_actual,
actual_type,
14 changes: 14 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
@@ -217,6 +218,8 @@ def _analyze_member_access(
# TODO: This and following functions share some logic with subtypes.find_member;
# consider refactoring.
typ = get_proper_type(typ)
# if name == "Bar":
# print("here")
if isinstance(typ, Instance):
return analyze_instance_member_access(name, typ, mx, override_info)
elif isinstance(typ, AnyType):
@@ -397,6 +400,17 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member
ret_type, name, mx, original_vars=typ.items[0].variables, mcs_fallback=typ.fallback
)
if result:
if isinstance(result, CallableType):
from mypy.expandtype import expand_type

env: dict[TypeVarId, TypeVarLikeType] = {t.id: t for t in result.variables}
env.update(
{
t.id: type_
for t, type_ in zip(ret_type.type.defn.type_vars, ret_type.args)
}
)
return expand_type(result, env)
return result
# Look up from the 'type' type.
return _analyze_member_access(name, typ.fallback, mx)
2 changes: 2 additions & 0 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
@@ -328,6 +328,8 @@ def _infer_constraints(
template = mypy.typeops.make_simplified_union(template.items, keep_erased=True)
if isinstance(actual, UnionType):
actual = mypy.typeops.make_simplified_union(actual.items, keep_erased=True)
if isinstance(actual, TypeVarType) and actual.has_default():
actual = get_proper_type(actual.default)

# Ignore Any types from the type suggestion engine to avoid them
# causing us to infer Any in situations where a better job could
36 changes: 33 additions & 3 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,8 @@
UnpackType,
flatten_nested_unions,
get_proper_type,
has_param_specs,
has_type_var_like_default,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import split_with_instance
@@ -125,6 +127,11 @@ def freshen_function_type_vars(callee: F) -> F:
tvmap: dict[TypeVarId, Type] = {}
for v in callee.variables:
tv = v.new_unification_variable(v)
if isinstance(tv.default, tv.__class__):
try:
tv.default = tvmap[tv.default.id]
except KeyError:
tv.default = tv.default.default
tvs.append(tv)
tvmap[v.id] = tv
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
@@ -179,6 +186,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
self.variables = variables
self.recursive_guard: set[Type | tuple[int, Type]] = set()

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
@@ -222,15 +230,37 @@ def visit_type_var(self, t: TypeVarType) -> Type:
if t.id.raw_id == 0:
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
repl = self.variables.get(t.id, t)
if isinstance(repl, ProperType) and isinstance(repl, Instance):

if has_type_var_like_default(repl):
if repl in self.recursive_guard:
return repl
self.recursive_guard.add(repl)
repl = repl.accept(self)
if isinstance(repl, TypeVarType):
repl.default = repl.default.accept(self)

if isinstance(repl, Instance):
# TODO: do we really need to do this?
# If I try to remove this special-casing ~40 tests fail on reveal_type().
return repl.copy_modified(last_known_value=None)
return repl

def visit_param_spec(self, t: ParamSpecType) -> Type:
# Set prefix to something empty, so we don't duplicate it below.
repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
# Set prefix to something empty so we don't duplicate below.
repl = get_proper_type(
self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
)

if has_param_specs(repl) and not isinstance(repl, Instance):
if (t.flavor, repl) in self.recursive_guard:
return repl
self.recursive_guard.add((t.flavor, repl))
repl = repl.accept(self)

if isinstance(repl, Instance):
# TODO: what does prefix mean in this case?
# TODO: why does this case even happen? Instances aren't plural.
return repl
if isinstance(repl, ParamSpecType):
return repl.copy_modified(
flavor=t.flavor,
10 changes: 10 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
@@ -2582,6 +2582,7 @@ class TypeVarTupleExpr(TypeVarLikeExpr):
__slots__ = "tuple_fallback"

tuple_fallback: mypy.types.Instance
default: mypy.types.UnpackType | mypy.types.AnyType # Unpack or Any

__match_args__ = ("name", "upper_bound", "default")

@@ -2852,6 +2853,7 @@ class is generic then it will be a type constructor of higher kind.
"fallback_to_any",
"meta_fallback_to_any",
"type_vars",
"has_type_var_default",
"has_param_spec_type",
"bases",
"_promote",
@@ -2956,6 +2958,9 @@ class is generic then it will be a type constructor of higher kind.
# Generic type variable names (full names)
type_vars: list[str]

# Whether this class has a TypeVar with a default value
has_type_var_default: bool

# Whether this class has a ParamSpec type variable
has_param_spec_type: bool

@@ -3043,6 +3048,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
self.defn = defn
self.module_name = module_name
self.type_vars = []
self.has_type_var_default = False
self.has_param_spec_type = False
self.has_type_var_tuple_type = False
self.bases = []
@@ -3084,6 +3090,8 @@ def add_type_vars(self) -> None:
self.has_type_var_tuple_type = False
if self.defn.type_vars:
for i, vd in enumerate(self.defn.type_vars):
if vd.has_default():
self.has_type_var_default = True
if isinstance(vd, mypy.types.ParamSpecType):
self.has_param_spec_type = True
if isinstance(vd, mypy.types.TypeVarTupleType):
@@ -3279,6 +3287,7 @@ def serialize(self) -> JsonDict:
"defn": self.defn.serialize(),
"abstract_attributes": self.abstract_attributes,
"type_vars": self.type_vars,
"has_type_var_default": self.has_type_var_default,
"has_param_spec_type": self.has_param_spec_type,
"bases": [b.serialize() for b in self.bases],
"mro": [c.fullname for c in self.mro],
@@ -3317,6 +3326,7 @@ def deserialize(cls, data: JsonDict) -> TypeInfo:
# TODO: Is there a reason to reconstruct ti.subtypes?
ti.abstract_attributes = [(attr[0], attr[1]) for attr in data["abstract_attributes"]]
ti.type_vars = data["type_vars"]
ti.has_type_var_default = data["has_type_var_default"]
ti.has_param_spec_type = data["has_param_spec_type"]
ti.bases = [mypy.types.Instance.deserialize(b) for b in data["bases"]]
_promote = []
86 changes: 52 additions & 34 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
@@ -1081,7 +1081,9 @@ def setup_self_type(self) -> None:
assert self.type is not None
info = self.type
if info.self_type is not None:
if has_placeholder(info.self_type.upper_bound):
if has_placeholder(info.self_type.upper_bound) or has_placeholder(
info.self_type.default
):
# Similar to regular (user defined) type variables.
self.process_placeholder(
None,
@@ -1642,9 +1644,10 @@ def analyze_class(self, defn: ClassDef) -> None:
# Some type variable bounds or values are not ready, we need
# to re-analyze this class.
self.defer()
if has_placeholder(tvd.default):
if isinstance(tvd, TypeVarLikeType) and has_placeholder(tvd.default):
# Placeholder values in TypeVarLikeTypes may get substituted in.
# Defer current target until they are ready.
self.defer()
self.mark_incomplete(defn.name, defn)
return

@@ -1954,7 +1957,16 @@ class Foo(Bar, Generic[T]): ...
del base_type_exprs[i]
tvar_defs: list[TypeVarLikeType] = []
for name, tvar_expr in declared_tvars:
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
if isinstance(tvar_expr.default, UnboundType):
# assumption here is that the names cannot be duplicated
for fullname, type_var in self.tvar_scope.scope.items():
_, _, default_type_var_name = fullname.rpartition(".")
if tvar_expr.default.name == default_type_var_name:
tvar_expr.default = type_var
# TODO(PEP 696) detect out of order typevars
tvar_def = self.tvar_scope.get_binding(tvar_expr.fullname)
if tvar_def is None:
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
tvar_defs.append(tvar_def)
return base_type_exprs, tvar_defs, is_protocol

@@ -2016,21 +2028,33 @@ def analyze_unbound_tvar_impl(
if sym and isinstance(sym.node, PlaceholderNode):
self.record_incomplete_ref()
if not allow_tvt and sym and isinstance(sym.node, ParamSpecExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
# if (
# sym.fullname
# and not self.tvar_scope.allow_binding(sym.fullname)
# and self.tvar_scope.parent
# and self.tvar_scope.parent.allow_binding(sym.fullname)
# ):
# # It's bound by our type variable scope
# return None
return t.name, sym.node
if allow_tvt and sym and isinstance(sym.node, TypeVarTupleExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
# if (
# sym.fullname
# and not self.tvar_scope.allow_binding(sym.fullname)
# and self.tvar_scope.parent
# and self.tvar_scope.parent.allow_binding(sym.fullname)
# ):
# # It's bound by our type variable scope
# return None
return t.name, sym.node
if sym is None or not isinstance(sym.node, TypeVarExpr) or allow_tvt:
return None
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
# elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# # It's bound by our type variable scope
# return None
else:
if isinstance(sym.node.default, sym.node.__class__):
return None
assert isinstance(sym.node, TypeVarExpr)
return t.name, sym.node

@@ -4131,6 +4155,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
type_var.line = call.line
call.analyzed = type_var
updated = True
# self.tvar_scope.bind_new(name, type_var) # TODO!
else:
assert isinstance(call.analyzed, TypeVarExpr)
updated = (
@@ -4139,6 +4164,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
or default != call.analyzed.default
)
call.analyzed.upper_bound = upper_bound
call.analyzed.default = default
call.analyzed.values = values
call.analyzed.default = default
if any(has_placeholder(v) for v in values):
@@ -4281,6 +4307,8 @@ def get_typevarlike_argument(
allow_unbound_tvars=allow_unbound_tvars,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
tvar_scope=self.tvar_scope,
allow_tuple_literal=True,
)
if analyzed is None:
# Type variables are special: we need to place them in the symbol table
@@ -4395,6 +4423,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool:
paramspec_var.line = call.line
call.analyzed = paramspec_var
updated = True
# self.tvar_scope.bind_new(name, paramspec_var) # TODO!
else:
assert isinstance(call.analyzed, ParamSpecExpr)
updated = default != call.analyzed.default
@@ -4463,6 +4492,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool:
typevartuple_var.line = call.line
call.analyzed = typevartuple_var
updated = True
# self.tvar_scope.bind_new(name, typevartuple_var) # TODO!
else:
assert isinstance(call.analyzed, TypeVarTupleExpr)
updated = default != call.analyzed.default
@@ -4968,9 +4998,10 @@ def visit_name_expr(self, expr: NameExpr) -> None:

def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None:
"""Bind name expression to a symbol table node."""
if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym):
self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr)
elif isinstance(sym.node, PlaceholderNode):
# TODO renenable this check, its fine for defaults
# if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym):
# self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr)
if isinstance(sym.node, PlaceholderNode):
self.process_placeholder(expr.name, "name", expr)
else:
expr.kind = sym.kind
@@ -6551,16 +6582,7 @@ def accept(self, node: Node) -> None:
except Exception as err:
report_internal_error(err, self.errors.file, node.line, self.errors, self.options)

def expr_to_analyzed_type(
self,
expr: Expression,
report_invalid_types: bool = True,
allow_placeholder: bool = False,
allow_type_any: bool = False,
allow_unbound_tvars: bool = False,
allow_param_spec_literals: bool = False,
allow_unpack: bool = False,
) -> Type | None:
def expr_to_analyzed_type(self, expr: Expression, **kwargs: Any) -> Type | None:
if isinstance(expr, CallExpr):
# This is a legacy syntax intended mostly for Python 2, we keep it for
# backwards compatibility, but new features like generic named tuples
@@ -6582,16 +6604,12 @@ def expr_to_analyzed_type(
assert info.tuple_type, "NamedTuple without tuple type"
fallback = Instance(info, [])
return TupleType(info.tuple_type.items, fallback=fallback)
# print(expr)
typ = self.expr_to_unanalyzed_type(expr)
return self.anal_type(
typ,
report_invalid_types=report_invalid_types,
allow_placeholder=allow_placeholder,
allow_type_any=allow_type_any,
allow_unbound_tvars=allow_unbound_tvars,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
)
# print("type", typ)
analised = self.anal_type(typ, **kwargs)
# print(f"{analised=}")
return analised

def analyze_type_expr(self, expr: Expression) -> None:
# There are certain expressions that mypy does not need to semantically analyze,
Loading