Skip to content

TEST: TypeVar defaults #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: my-main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6750,8 +6750,8 @@ def named_type(self, name: str) -> Instance:
assert isinstance(node.target, Instance) # type: ignore[misc]
node = node.target.type
assert isinstance(node, TypeInfo)
any_type = AnyType(TypeOfAny.from_omitted_generics)
return Instance(node, [any_type] * len(node.defn.type_vars))

return Instance(node, [type_var.default for type_var in node.defn.type_vars])

def named_generic_type(self, name: str, args: list[Type]) -> Instance:
"""Return an instance with the given name and type arguments.
Expand All @@ -6760,7 +6760,11 @@ def named_generic_type(self, name: str, args: list[Type]) -> Instance:
the name refers to a compatible generic type.
"""
info = self.lookup_typeinfo(name)
args = [remove_instance_last_known_values(arg) for arg in args]
args = [
remove_instance_last_known_values(arg)
for arg in args
# + [tv.default for tv in info.defn.type_vars[len(args) - len(info.defn.type_vars) :]]
]
# TODO: assert len(args) == len(info.defn.type_vars)
return Instance(info, args)

Expand Down
5 changes: 5 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,6 +2542,11 @@ def check_argument_types(
callee_arg_kind,
allow_unpack=isinstance(callee_arg_type, UnpackType),
)
if (
isinstance(callee_arg_type, types.TypeVarType)
and callee_arg_type.has_default()
):
callee_arg_type = callee_arg_type.default
check_arg(
expanded_actual,
actual_type,
Expand Down
14 changes: 14 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
Expand Down Expand Up @@ -217,6 +218,8 @@ def _analyze_member_access(
# TODO: This and following functions share some logic with subtypes.find_member;
# consider refactoring.
typ = get_proper_type(typ)
# if name == "Bar":
# print("here")
if isinstance(typ, Instance):
return analyze_instance_member_access(name, typ, mx, override_info)
elif isinstance(typ, AnyType):
Expand Down Expand Up @@ -397,6 +400,17 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member
ret_type, name, mx, original_vars=typ.items[0].variables, mcs_fallback=typ.fallback
)
if result:
if isinstance(result, CallableType):
from mypy.expandtype import expand_type

env: dict[TypeVarId, TypeVarLikeType] = {t.id: t for t in result.variables}
env.update(
{
t.id: type_
for t, type_ in zip(ret_type.type.defn.type_vars, ret_type.args)
}
)
return expand_type(result, env)
return result
# Look up from the 'type' type.
return _analyze_member_access(name, typ.fallback, mx)
Expand Down
2 changes: 2 additions & 0 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def _infer_constraints(
template = mypy.typeops.make_simplified_union(template.items, keep_erased=True)
if isinstance(actual, UnionType):
actual = mypy.typeops.make_simplified_union(actual.items, keep_erased=True)
if isinstance(actual, TypeVarType) and actual.has_default():
actual = get_proper_type(actual.default)

# Ignore Any types from the type suggestion engine to avoid them
# causing us to infer Any in situations where a better job could
Expand Down
36 changes: 33 additions & 3 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
UnpackType,
flatten_nested_unions,
get_proper_type,
has_param_specs,
has_type_var_like_default,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import split_with_instance
Expand Down Expand Up @@ -125,6 +127,11 @@ def freshen_function_type_vars(callee: F) -> F:
tvmap: dict[TypeVarId, Type] = {}
for v in callee.variables:
tv = v.new_unification_variable(v)
if isinstance(tv.default, tv.__class__):
try:
tv.default = tvmap[tv.default.id]
except KeyError:
tv.default = tv.default.default
tvs.append(tv)
tvmap[v.id] = tv
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
Expand Down Expand Up @@ -179,6 +186,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
self.variables = variables
self.recursive_guard: set[Type | tuple[int, Type]] = set()

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
Expand Down Expand Up @@ -222,15 +230,37 @@ def visit_type_var(self, t: TypeVarType) -> Type:
if t.id.raw_id == 0:
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
repl = self.variables.get(t.id, t)
if isinstance(repl, ProperType) and isinstance(repl, Instance):

if has_type_var_like_default(repl):
if repl in self.recursive_guard:
return repl
self.recursive_guard.add(repl)
repl = repl.accept(self)
if isinstance(repl, TypeVarType):
repl.default = repl.default.accept(self)

if isinstance(repl, Instance):
# TODO: do we really need to do this?
# If I try to remove this special-casing ~40 tests fail on reveal_type().
return repl.copy_modified(last_known_value=None)
return repl

def visit_param_spec(self, t: ParamSpecType) -> Type:
# Set prefix to something empty, so we don't duplicate it below.
repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
# Set prefix to something empty so we don't duplicate below.
repl = get_proper_type(
self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
)

if has_param_specs(repl) and not isinstance(repl, Instance):
if (t.flavor, repl) in self.recursive_guard:
return repl
self.recursive_guard.add((t.flavor, repl))
repl = repl.accept(self)

if isinstance(repl, Instance):
# TODO: what does prefix mean in this case?
# TODO: why does this case even happen? Instances aren't plural.
return repl
if isinstance(repl, ParamSpecType):
return repl.copy_modified(
flavor=t.flavor,
Expand Down
10 changes: 10 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2582,6 +2582,7 @@ class TypeVarTupleExpr(TypeVarLikeExpr):
__slots__ = "tuple_fallback"

tuple_fallback: mypy.types.Instance
default: mypy.types.UnpackType | mypy.types.AnyType # Unpack or Any

__match_args__ = ("name", "upper_bound", "default")

Expand Down Expand Up @@ -2852,6 +2853,7 @@ class is generic then it will be a type constructor of higher kind.
"fallback_to_any",
"meta_fallback_to_any",
"type_vars",
"has_type_var_default",
"has_param_spec_type",
"bases",
"_promote",
Expand Down Expand Up @@ -2956,6 +2958,9 @@ class is generic then it will be a type constructor of higher kind.
# Generic type variable names (full names)
type_vars: list[str]

# Whether this class has a TypeVar with a default value
has_type_var_default: bool

# Whether this class has a ParamSpec type variable
has_param_spec_type: bool

Expand Down Expand Up @@ -3043,6 +3048,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
self.defn = defn
self.module_name = module_name
self.type_vars = []
self.has_type_var_default = False
self.has_param_spec_type = False
self.has_type_var_tuple_type = False
self.bases = []
Expand Down Expand Up @@ -3084,6 +3090,8 @@ def add_type_vars(self) -> None:
self.has_type_var_tuple_type = False
if self.defn.type_vars:
for i, vd in enumerate(self.defn.type_vars):
if vd.has_default():
self.has_type_var_default = True
if isinstance(vd, mypy.types.ParamSpecType):
self.has_param_spec_type = True
if isinstance(vd, mypy.types.TypeVarTupleType):
Expand Down Expand Up @@ -3279,6 +3287,7 @@ def serialize(self) -> JsonDict:
"defn": self.defn.serialize(),
"abstract_attributes": self.abstract_attributes,
"type_vars": self.type_vars,
"has_type_var_default": self.has_type_var_default,
"has_param_spec_type": self.has_param_spec_type,
"bases": [b.serialize() for b in self.bases],
"mro": [c.fullname for c in self.mro],
Expand Down Expand Up @@ -3317,6 +3326,7 @@ def deserialize(cls, data: JsonDict) -> TypeInfo:
# TODO: Is there a reason to reconstruct ti.subtypes?
ti.abstract_attributes = [(attr[0], attr[1]) for attr in data["abstract_attributes"]]
ti.type_vars = data["type_vars"]
ti.has_type_var_default = data["has_type_var_default"]
ti.has_param_spec_type = data["has_param_spec_type"]
ti.bases = [mypy.types.Instance.deserialize(b) for b in data["bases"]]
_promote = []
Expand Down
86 changes: 52 additions & 34 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,7 +1081,9 @@ def setup_self_type(self) -> None:
assert self.type is not None
info = self.type
if info.self_type is not None:
if has_placeholder(info.self_type.upper_bound):
if has_placeholder(info.self_type.upper_bound) or has_placeholder(
info.self_type.default
):
# Similar to regular (user defined) type variables.
self.process_placeholder(
None,
Expand Down Expand Up @@ -1642,9 +1644,10 @@ def analyze_class(self, defn: ClassDef) -> None:
# Some type variable bounds or values are not ready, we need
# to re-analyze this class.
self.defer()
if has_placeholder(tvd.default):
if isinstance(tvd, TypeVarLikeType) and has_placeholder(tvd.default):
# Placeholder values in TypeVarLikeTypes may get substituted in.
# Defer current target until they are ready.
self.defer()
self.mark_incomplete(defn.name, defn)
return

Expand Down Expand Up @@ -1954,7 +1957,16 @@ class Foo(Bar, Generic[T]): ...
del base_type_exprs[i]
tvar_defs: list[TypeVarLikeType] = []
for name, tvar_expr in declared_tvars:
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
if isinstance(tvar_expr.default, UnboundType):
# assumption here is that the names cannot be duplicated
for fullname, type_var in self.tvar_scope.scope.items():
_, _, default_type_var_name = fullname.rpartition(".")
if tvar_expr.default.name == default_type_var_name:
tvar_expr.default = type_var
# TODO(PEP 696) detect out of order typevars
tvar_def = self.tvar_scope.get_binding(tvar_expr.fullname)
if tvar_def is None:
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
tvar_defs.append(tvar_def)
return base_type_exprs, tvar_defs, is_protocol

Expand Down Expand Up @@ -2016,21 +2028,33 @@ def analyze_unbound_tvar_impl(
if sym and isinstance(sym.node, PlaceholderNode):
self.record_incomplete_ref()
if not allow_tvt and sym and isinstance(sym.node, ParamSpecExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
# if (
# sym.fullname
# and not self.tvar_scope.allow_binding(sym.fullname)
# and self.tvar_scope.parent
# and self.tvar_scope.parent.allow_binding(sym.fullname)
# ):
# # It's bound by our type variable scope
# return None
return t.name, sym.node
if allow_tvt and sym and isinstance(sym.node, TypeVarTupleExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
# if (
# sym.fullname
# and not self.tvar_scope.allow_binding(sym.fullname)
# and self.tvar_scope.parent
# and self.tvar_scope.parent.allow_binding(sym.fullname)
# ):
# # It's bound by our type variable scope
# return None
return t.name, sym.node
if sym is None or not isinstance(sym.node, TypeVarExpr) or allow_tvt:
return None
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
# elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# # It's bound by our type variable scope
# return None
else:
if isinstance(sym.node.default, sym.node.__class__):
return None
assert isinstance(sym.node, TypeVarExpr)
return t.name, sym.node

Expand Down Expand Up @@ -4131,6 +4155,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
type_var.line = call.line
call.analyzed = type_var
updated = True
# self.tvar_scope.bind_new(name, type_var) # TODO!
else:
assert isinstance(call.analyzed, TypeVarExpr)
updated = (
Expand All @@ -4139,6 +4164,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
or default != call.analyzed.default
)
call.analyzed.upper_bound = upper_bound
call.analyzed.default = default
call.analyzed.values = values
call.analyzed.default = default
if any(has_placeholder(v) for v in values):
Expand Down Expand Up @@ -4281,6 +4307,8 @@ def get_typevarlike_argument(
allow_unbound_tvars=allow_unbound_tvars,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
tvar_scope=self.tvar_scope,
allow_tuple_literal=True,
)
if analyzed is None:
# Type variables are special: we need to place them in the symbol table
Expand Down Expand Up @@ -4395,6 +4423,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool:
paramspec_var.line = call.line
call.analyzed = paramspec_var
updated = True
# self.tvar_scope.bind_new(name, paramspec_var) # TODO!
else:
assert isinstance(call.analyzed, ParamSpecExpr)
updated = default != call.analyzed.default
Expand Down Expand Up @@ -4463,6 +4492,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool:
typevartuple_var.line = call.line
call.analyzed = typevartuple_var
updated = True
# self.tvar_scope.bind_new(name, typevartuple_var) # TODO!
else:
assert isinstance(call.analyzed, TypeVarTupleExpr)
updated = default != call.analyzed.default
Expand Down Expand Up @@ -4968,9 +4998,10 @@ def visit_name_expr(self, expr: NameExpr) -> None:

def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None:
"""Bind name expression to a symbol table node."""
if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym):
self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr)
elif isinstance(sym.node, PlaceholderNode):
# TODO renenable this check, its fine for defaults
# if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym):
# self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr)
if isinstance(sym.node, PlaceholderNode):
self.process_placeholder(expr.name, "name", expr)
else:
expr.kind = sym.kind
Expand Down Expand Up @@ -6551,16 +6582,7 @@ def accept(self, node: Node) -> None:
except Exception as err:
report_internal_error(err, self.errors.file, node.line, self.errors, self.options)

def expr_to_analyzed_type(
self,
expr: Expression,
report_invalid_types: bool = True,
allow_placeholder: bool = False,
allow_type_any: bool = False,
allow_unbound_tvars: bool = False,
allow_param_spec_literals: bool = False,
allow_unpack: bool = False,
) -> Type | None:
def expr_to_analyzed_type(self, expr: Expression, **kwargs: Any) -> Type | None:
if isinstance(expr, CallExpr):
# This is a legacy syntax intended mostly for Python 2, we keep it for
# backwards compatibility, but new features like generic named tuples
Expand All @@ -6582,16 +6604,12 @@ def expr_to_analyzed_type(
assert info.tuple_type, "NamedTuple without tuple type"
fallback = Instance(info, [])
return TupleType(info.tuple_type.items, fallback=fallback)
# print(expr)
typ = self.expr_to_unanalyzed_type(expr)
return self.anal_type(
typ,
report_invalid_types=report_invalid_types,
allow_placeholder=allow_placeholder,
allow_type_any=allow_type_any,
allow_unbound_tvars=allow_unbound_tvars,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
)
# print("type", typ)
analised = self.anal_type(typ, **kwargs)
# print(f"{analised=}")
return analised

def analyze_type_expr(self, expr: Expression) -> None:
# There are certain expressions that mypy does not need to semantically analyze,
Expand Down
Loading