diff --git a/mypy/checker.py b/mypy/checker.py index 391f28e93b1d..3ed14f88c29e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6750,8 +6750,8 @@ def named_type(self, name: str) -> Instance: assert isinstance(node.target, Instance) # type: ignore[misc] node = node.target.type assert isinstance(node, TypeInfo) - any_type = AnyType(TypeOfAny.from_omitted_generics) - return Instance(node, [any_type] * len(node.defn.type_vars)) + + return Instance(node, [type_var.default for type_var in node.defn.type_vars]) def named_generic_type(self, name: str, args: list[Type]) -> Instance: """Return an instance with the given name and type arguments. @@ -6760,7 +6760,11 @@ def named_generic_type(self, name: str, args: list[Type]) -> Instance: the name refers to a compatible generic type. """ info = self.lookup_typeinfo(name) - args = [remove_instance_last_known_values(arg) for arg in args] + args = [ + remove_instance_last_known_values(arg) + for arg in args + # + [tv.default for tv in info.defn.type_vars[len(args) - len(info.defn.type_vars) :]] + ] # TODO: assert len(args) == len(info.defn.type_vars) return Instance(info, args) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ff7b7fa2ff58..4049a478bfe5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2542,6 +2542,11 @@ def check_argument_types( callee_arg_kind, allow_unpack=isinstance(callee_arg_type, UnpackType), ) + if ( + isinstance(callee_arg_type, types.TypeVarType) + and callee_arg_type.has_default() + ): + callee_arg_type = callee_arg_type.default check_arg( expanded_actual, actual_type, diff --git a/mypy/checkmember.py b/mypy/checkmember.py index c24edacf0ee1..719507e6633c 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -65,6 +65,7 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarId, TypeVarLikeType, TypeVarTupleType, TypeVarType, @@ -217,6 +218,8 @@ def _analyze_member_access( # TODO: This and following functions share some logic with subtypes.find_member; # consider refactoring. typ = get_proper_type(typ) + # if name == "Bar": + # print("here") if isinstance(typ, Instance): return analyze_instance_member_access(name, typ, mx, override_info) elif isinstance(typ, AnyType): @@ -397,6 +400,17 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member ret_type, name, mx, original_vars=typ.items[0].variables, mcs_fallback=typ.fallback ) if result: + if isinstance(result, CallableType): + from mypy.expandtype import expand_type + + env: dict[TypeVarId, TypeVarLikeType] = {t.id: t for t in result.variables} + env.update( + { + t.id: type_ + for t, type_ in zip(ret_type.type.defn.type_vars, ret_type.args) + } + ) + return expand_type(result, env) return result # Look up from the 'type' type. return _analyze_member_access(name, typ.fallback, mx) diff --git a/mypy/constraints.py b/mypy/constraints.py index c4eba2ca1ede..e3ad1b18bf96 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -328,6 +328,8 @@ def _infer_constraints( template = mypy.typeops.make_simplified_union(template.items, keep_erased=True) if isinstance(actual, UnionType): actual = mypy.typeops.make_simplified_union(actual.items, keep_erased=True) + if isinstance(actual, TypeVarType) and actual.has_default(): + actual = get_proper_type(actual.default) # Ignore Any types from the type suggestion engine to avoid them # causing us to infer Any in situations where a better job could diff --git a/mypy/expandtype.py b/mypy/expandtype.py index d2d294fb77f3..6a8f0ec7bc29 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -37,6 +37,8 @@ UnpackType, flatten_nested_unions, get_proper_type, + has_param_specs, + has_type_var_like_default, split_with_prefix_and_suffix, ) from mypy.typevartuples import split_with_instance @@ -125,6 +127,11 @@ def freshen_function_type_vars(callee: F) -> F: tvmap: dict[TypeVarId, Type] = {} for v in callee.variables: tv = v.new_unification_variable(v) + if isinstance(tv.default, tv.__class__): + try: + tv.default = tvmap[tv.default.id] + except KeyError: + tv.default = tv.default.default tvs.append(tv) tvmap[v.id] = tv fresh = expand_type(callee, tvmap).copy_modified(variables=tvs) @@ -179,6 +186,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator): def __init__(self, variables: Mapping[TypeVarId, Type]) -> None: self.variables = variables + self.recursive_guard: set[Type | tuple[int, Type]] = set() def visit_unbound_type(self, t: UnboundType) -> Type: return t @@ -222,15 +230,37 @@ def visit_type_var(self, t: TypeVarType) -> Type: if t.id.raw_id == 0: t = t.copy_modified(upper_bound=t.upper_bound.accept(self)) repl = self.variables.get(t.id, t) - if isinstance(repl, ProperType) and isinstance(repl, Instance): + + if has_type_var_like_default(repl): + if repl in self.recursive_guard: + return repl + self.recursive_guard.add(repl) + repl = repl.accept(self) + if isinstance(repl, TypeVarType): + repl.default = repl.default.accept(self) + + if isinstance(repl, Instance): # TODO: do we really need to do this? # If I try to remove this special-casing ~40 tests fail on reveal_type(). return repl.copy_modified(last_known_value=None) return repl def visit_param_spec(self, t: ParamSpecType) -> Type: - # Set prefix to something empty, so we don't duplicate it below. - repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], []))) + # Set prefix to something empty so we don't duplicate below. + repl = get_proper_type( + self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], []))) + ) + + if has_param_specs(repl) and not isinstance(repl, Instance): + if (t.flavor, repl) in self.recursive_guard: + return repl + self.recursive_guard.add((t.flavor, repl)) + repl = repl.accept(self) + + if isinstance(repl, Instance): + # TODO: what does prefix mean in this case? + # TODO: why does this case even happen? Instances aren't plural. + return repl if isinstance(repl, ParamSpecType): return repl.copy_modified( flavor=t.flavor, diff --git a/mypy/nodes.py b/mypy/nodes.py index 1c781320580a..d42faa640d47 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2582,6 +2582,7 @@ class TypeVarTupleExpr(TypeVarLikeExpr): __slots__ = "tuple_fallback" tuple_fallback: mypy.types.Instance + default: mypy.types.UnpackType | mypy.types.AnyType # Unpack or Any __match_args__ = ("name", "upper_bound", "default") @@ -2852,6 +2853,7 @@ class is generic then it will be a type constructor of higher kind. "fallback_to_any", "meta_fallback_to_any", "type_vars", + "has_type_var_default", "has_param_spec_type", "bases", "_promote", @@ -2956,6 +2958,9 @@ class is generic then it will be a type constructor of higher kind. # Generic type variable names (full names) type_vars: list[str] + # Whether this class has a TypeVar with a default value + has_type_var_default: bool + # Whether this class has a ParamSpec type variable has_param_spec_type: bool @@ -3043,6 +3048,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None self.defn = defn self.module_name = module_name self.type_vars = [] + self.has_type_var_default = False self.has_param_spec_type = False self.has_type_var_tuple_type = False self.bases = [] @@ -3084,6 +3090,8 @@ def add_type_vars(self) -> None: self.has_type_var_tuple_type = False if self.defn.type_vars: for i, vd in enumerate(self.defn.type_vars): + if vd.has_default(): + self.has_type_var_default = True if isinstance(vd, mypy.types.ParamSpecType): self.has_param_spec_type = True if isinstance(vd, mypy.types.TypeVarTupleType): @@ -3279,6 +3287,7 @@ def serialize(self) -> JsonDict: "defn": self.defn.serialize(), "abstract_attributes": self.abstract_attributes, "type_vars": self.type_vars, + "has_type_var_default": self.has_type_var_default, "has_param_spec_type": self.has_param_spec_type, "bases": [b.serialize() for b in self.bases], "mro": [c.fullname for c in self.mro], @@ -3317,6 +3326,7 @@ def deserialize(cls, data: JsonDict) -> TypeInfo: # TODO: Is there a reason to reconstruct ti.subtypes? ti.abstract_attributes = [(attr[0], attr[1]) for attr in data["abstract_attributes"]] ti.type_vars = data["type_vars"] + ti.has_type_var_default = data["has_type_var_default"] ti.has_param_spec_type = data["has_param_spec_type"] ti.bases = [mypy.types.Instance.deserialize(b) for b in data["bases"]] _promote = [] diff --git a/mypy/semanal.py b/mypy/semanal.py index 4bf9f0c3eabb..acc6c64bfa59 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1081,7 +1081,9 @@ def setup_self_type(self) -> None: assert self.type is not None info = self.type if info.self_type is not None: - if has_placeholder(info.self_type.upper_bound): + if has_placeholder(info.self_type.upper_bound) or has_placeholder( + info.self_type.default + ): # Similar to regular (user defined) type variables. self.process_placeholder( None, @@ -1642,9 +1644,10 @@ def analyze_class(self, defn: ClassDef) -> None: # Some type variable bounds or values are not ready, we need # to re-analyze this class. self.defer() - if has_placeholder(tvd.default): + if isinstance(tvd, TypeVarLikeType) and has_placeholder(tvd.default): # Placeholder values in TypeVarLikeTypes may get substituted in. # Defer current target until they are ready. + self.defer() self.mark_incomplete(defn.name, defn) return @@ -1954,7 +1957,16 @@ class Foo(Bar, Generic[T]): ... del base_type_exprs[i] tvar_defs: list[TypeVarLikeType] = [] for name, tvar_expr in declared_tvars: - tvar_def = self.tvar_scope.bind_new(name, tvar_expr) + if isinstance(tvar_expr.default, UnboundType): + # assumption here is that the names cannot be duplicated + for fullname, type_var in self.tvar_scope.scope.items(): + _, _, default_type_var_name = fullname.rpartition(".") + if tvar_expr.default.name == default_type_var_name: + tvar_expr.default = type_var + # TODO(PEP 696) detect out of order typevars + tvar_def = self.tvar_scope.get_binding(tvar_expr.fullname) + if tvar_def is None: + tvar_def = self.tvar_scope.bind_new(name, tvar_expr) tvar_defs.append(tvar_def) return base_type_exprs, tvar_defs, is_protocol @@ -2016,21 +2028,33 @@ def analyze_unbound_tvar_impl( if sym and isinstance(sym.node, PlaceholderNode): self.record_incomplete_ref() if not allow_tvt and sym and isinstance(sym.node, ParamSpecExpr): - if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname): - # It's bound by our type variable scope - return None + # if ( + # sym.fullname + # and not self.tvar_scope.allow_binding(sym.fullname) + # and self.tvar_scope.parent + # and self.tvar_scope.parent.allow_binding(sym.fullname) + # ): + # # It's bound by our type variable scope + # return None return t.name, sym.node if allow_tvt and sym and isinstance(sym.node, TypeVarTupleExpr): - if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname): - # It's bound by our type variable scope - return None + # if ( + # sym.fullname + # and not self.tvar_scope.allow_binding(sym.fullname) + # and self.tvar_scope.parent + # and self.tvar_scope.parent.allow_binding(sym.fullname) + # ): + # # It's bound by our type variable scope + # return None return t.name, sym.node if sym is None or not isinstance(sym.node, TypeVarExpr) or allow_tvt: return None - elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname): - # It's bound by our type variable scope - return None + # elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname): + # # It's bound by our type variable scope + # return None else: + if isinstance(sym.node.default, sym.node.__class__): + return None assert isinstance(sym.node, TypeVarExpr) return t.name, sym.node @@ -4131,6 +4155,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: type_var.line = call.line call.analyzed = type_var updated = True + # self.tvar_scope.bind_new(name, type_var) # TODO! else: assert isinstance(call.analyzed, TypeVarExpr) updated = ( @@ -4139,6 +4164,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: or default != call.analyzed.default ) call.analyzed.upper_bound = upper_bound + call.analyzed.default = default call.analyzed.values = values call.analyzed.default = default if any(has_placeholder(v) for v in values): @@ -4281,6 +4307,8 @@ def get_typevarlike_argument( allow_unbound_tvars=allow_unbound_tvars, allow_param_spec_literals=allow_param_spec_literals, allow_unpack=allow_unpack, + tvar_scope=self.tvar_scope, + allow_tuple_literal=True, ) if analyzed is None: # Type variables are special: we need to place them in the symbol table @@ -4395,6 +4423,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool: paramspec_var.line = call.line call.analyzed = paramspec_var updated = True + # self.tvar_scope.bind_new(name, paramspec_var) # TODO! else: assert isinstance(call.analyzed, ParamSpecExpr) updated = default != call.analyzed.default @@ -4463,6 +4492,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool: typevartuple_var.line = call.line call.analyzed = typevartuple_var updated = True + # self.tvar_scope.bind_new(name, typevartuple_var) # TODO! else: assert isinstance(call.analyzed, TypeVarTupleExpr) updated = default != call.analyzed.default @@ -4968,9 +4998,10 @@ def visit_name_expr(self, expr: NameExpr) -> None: def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None: """Bind name expression to a symbol table node.""" - if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym): - self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr) - elif isinstance(sym.node, PlaceholderNode): + # TODO renenable this check, its fine for defaults + # if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym): + # self.fail(f'"{expr.name}" is a type variable and only valid in type context', expr) + if isinstance(sym.node, PlaceholderNode): self.process_placeholder(expr.name, "name", expr) else: expr.kind = sym.kind @@ -6551,16 +6582,7 @@ def accept(self, node: Node) -> None: except Exception as err: report_internal_error(err, self.errors.file, node.line, self.errors, self.options) - def expr_to_analyzed_type( - self, - expr: Expression, - report_invalid_types: bool = True, - allow_placeholder: bool = False, - allow_type_any: bool = False, - allow_unbound_tvars: bool = False, - allow_param_spec_literals: bool = False, - allow_unpack: bool = False, - ) -> Type | None: + def expr_to_analyzed_type(self, expr: Expression, **kwargs: Any) -> Type | None: if isinstance(expr, CallExpr): # This is a legacy syntax intended mostly for Python 2, we keep it for # backwards compatibility, but new features like generic named tuples @@ -6582,16 +6604,12 @@ def expr_to_analyzed_type( assert info.tuple_type, "NamedTuple without tuple type" fallback = Instance(info, []) return TupleType(info.tuple_type.items, fallback=fallback) + # print(expr) typ = self.expr_to_unanalyzed_type(expr) - return self.anal_type( - typ, - report_invalid_types=report_invalid_types, - allow_placeholder=allow_placeholder, - allow_type_any=allow_type_any, - allow_unbound_tvars=allow_unbound_tvars, - allow_param_spec_literals=allow_param_spec_literals, - allow_unpack=allow_unpack, - ) + # print("type", typ) + analised = self.anal_type(typ, **kwargs) + # print(f"{analised=}") + return analised def analyze_type_expr(self, expr: Expression) -> None: # There are certain expressions that mypy does not need to semantically analyze, diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 2d536f892a2a..f14c1b83e8fa 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -626,8 +626,9 @@ def visit_instance(self, left: Instance) -> bool: if call: return self._is_subtype(call, right) return False - else: - return False + if isinstance(right, TypeVarType) and right.has_default(): + return self._is_subtype(left, right.default) + return False def visit_type_var(self, left: TypeVarType) -> bool: right = self.right @@ -635,6 +636,9 @@ def visit_type_var(self, left: TypeVarType) -> bool: return True if left.values and self._is_subtype(UnionType.make_union(left.values), right): return True + if left.has_default(): + # Check if correct! + return self._is_subtype(left.default, self.right) return self._is_subtype(left.upper_bound, self.right) def visit_param_spec(self, left: ParamSpecType) -> bool: diff --git a/mypy/tvar_scope.py b/mypy/tvar_scope.py index c7a653a1552d..f27532c2eef5 100644 --- a/mypy/tvar_scope.py +++ b/mypy/tvar_scope.py @@ -1,5 +1,8 @@ from __future__ import annotations +from copy import copy +from typing import Iterator + from mypy.nodes import ( ParamSpecExpr, SymbolTableNode, @@ -7,16 +10,130 @@ TypeVarLikeExpr, TypeVarTupleExpr, ) +from mypy.type_visitor import SyntheticTypeVisitor from mypy.types import ( + AnyType, + CallableArgument, + CallableType, + DeletedType, + EllipsisType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, ParamSpecFlavor, ParamSpecType, + PartialType, + PlaceholderType, + RawExpressionType, + TupleType, + TypeAliasType, + TypedDictType, + TypeList, + TypeType, TypeVarId, TypeVarLikeType, TypeVarTupleType, TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, ) +class TypeVarLikeYielder(SyntheticTypeVisitor[Iterator[TypeVarLikeType]]): + """Yield all TypeVarLikeTypes in a type.""" + + def visit_type_var(self, t: TypeVarType) -> Iterator[TypeVarLikeType]: + yield t + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> Iterator[TypeVarLikeType]: + yield t + + def visit_param_spec(self, t: ParamSpecType) -> Iterator[TypeVarLikeType]: + yield t + + def visit_callable_type(self, t: CallableType) -> Iterator[TypeVarLikeType]: + for arg in t.arg_types: + yield from arg.accept(self) + yield from t.ret_type.accept(self) + + def visit_instance(self, t: Instance) -> Iterator[TypeVarLikeType]: + for arg in t.args: + yield from arg.accept(self) + + def visit_overloaded(self, t: Overloaded) -> Iterator[TypeVarLikeType]: + for item in t.items: + yield from item.accept(self) + + def visit_tuple_type(self, t: TupleType) -> Iterator[TypeVarLikeType]: + for item in t.items: + yield from item.accept(self) + + def visit_type_alias_type(self, t: TypeAliasType) -> Iterator[TypeVarLikeType]: + for arg in t.args: + yield from arg.accept(self) + + def visit_typeddict_type(self, t: TypedDictType) -> Iterator[TypeVarLikeType]: + for arg in t.items.values(): + yield from arg.accept(self) + + def visit_union_type(self, t: UnionType) -> Iterator[TypeVarLikeType]: + for arg in t.items: + yield from arg.accept(self) + + def visit_type_type(self, t: TypeType) -> Iterator[TypeVarLikeType]: + yield from t.item.accept(self) + + def visit_type_list(self, t: TypeList) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_callable_argument(self, t: CallableArgument) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_ellipsis_type(self, t: EllipsisType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_raw_expression_type(self, t: RawExpressionType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_unbound_type(self, t: UnboundType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_none_type(self, t: NoneType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_uninhabited_type(self, t: UninhabitedType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_erased_type(self, t: ErasedType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_deleted_type(self, t: DeletedType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_parameters(self, t: Parameters) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_literal_type(self, t: LiteralType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_partial_type(self, t: PartialType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_unpack_type(self, t: UnpackType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_any(self, t: AnyType) -> Iterator[TypeVarLikeType]: + yield from () + + def visit_placeholder_type(self, t: PlaceholderType) -> Iterator[TypeVarLikeType]: + yield from () + + class TypeVarLikeScope: """Scope that holds bindings for type variables and parameter specifications. @@ -82,12 +199,17 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: if self.is_class_scope: self.class_id += 1 i = self.class_id - namespace = self.namespace else: self.func_id -= 1 i = self.func_id - # TODO: Consider also using namespaces for functions - namespace = "" + namespace = self.namespace + # fix the namespace of any type vars + default = tvar_expr.default + + for tv in default.accept(TypeVarLikeYielder()): + tv = copy(tv) + tv.id.namespace = namespace + self.scope[tv.fullname] = tv if isinstance(tvar_expr, TypeVarExpr): tvar_def: TypeVarLikeType = TypeVarType( name=name, @@ -95,7 +217,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: id=TypeVarId(i, namespace=namespace), values=tvar_expr.values, upper_bound=tvar_expr.upper_bound, - default=tvar_expr.default, + default=default, variance=tvar_expr.variance, line=tvar_expr.line, column=tvar_expr.column, @@ -107,7 +229,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: i, flavor=ParamSpecFlavor.BARE, upper_bound=tvar_expr.upper_bound, - default=tvar_expr.default, + default=default, line=tvar_expr.line, column=tvar_expr.column, ) @@ -118,7 +240,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: i, upper_bound=tvar_expr.upper_bound, tuple_fallback=tvar_expr.tuple_fallback, - default=tvar_expr.default, + default=default, line=tvar_expr.line, column=tvar_expr.column, ) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 530793730f35..4269577df54e 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1854,6 +1854,7 @@ def fix_instance( note: MsgCallback, disallow_any: bool, options: Options, + # tv_scope: TypeVarLikeScope, use_generic_error: bool = False, unexpanded_type: Type | None = None, ) -> None: diff --git a/mypy/types.py b/mypy/types.py index b1119c9447e2..eefb5e345da1 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -515,13 +515,14 @@ def new(meta_level: int) -> TypeVarId: return TypeVarId(raw_id, meta_level) def __repr__(self) -> str: + # return f"TypeVarId({self.raw_id}, {self.meta_level}, {self.namespace})" return self.raw_id.__repr__() def __eq__(self, other: object) -> bool: return ( isinstance(other, TypeVarId) and self.raw_id == other.raw_id - and self.meta_level == other.meta_level + and self.meta_level == other.meta_level # TODO this probably breaks a lot of stuff and self.namespace == other.namespace ) @@ -576,6 +577,8 @@ def copy_modified(self, *, id: TypeVarId, **kwargs: Any) -> Self: @classmethod def new_unification_variable(cls, old: Self) -> Self: new_id = TypeVarId.new(meta_level=1) + # new_id.raw_id = old.id.raw_id + new_id.namespace = old.id.namespace return old.copy_modified(id=new_id) def has_default(self) -> bool: @@ -3183,7 +3186,7 @@ def visit_instance(self, t: Instance) -> str: if t.args: if t.type.fullname == "builtins.tuple": - assert len(t.args) == 1 + # assert len(t.args) == 1 s += f"[{self.list_str(t.args)}, ...]" else: s += f"[{self.list_str(t.args)}]" @@ -3493,6 +3496,47 @@ def has_type_vars(typ: Type) -> bool: return typ.accept(HasTypeVars()) +class HasTypeVarLikeDefault(BoolTypeQuery): + def __init__(self) -> None: + super().__init__(ANY_STRATEGY) + self.skip_alias_target = True + + def visit_type_var(self, t: TypeVarType) -> bool: + return t.has_default() + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool: + return t.has_default() + + def visit_param_spec(self, t: ParamSpecType) -> bool: + return t.has_default() + + +# Use singleton since this is hot (note: call reset() before using) +_has_type_var_like_default: Final = HasTypeVarLikeDefault() + + +def has_type_var_like_default(typ: Type) -> bool: + _has_type_var_like_default.reset() + return typ.accept(_has_type_var_like_default) + + +class HasParamSpecs(BoolTypeQuery): + def __init__(self) -> None: + super().__init__(ANY_STRATEGY) + + def visit_param_spec(self, t: ParamSpecType) -> bool: + return True + + +# Use singleton since this is hot (note: call reset() before using) +_has_param_specs: Final = HasParamSpecs() + + +def has_param_specs(typ: Type) -> bool: + _has_param_specs.reset() + return typ.accept(_has_param_specs) + + class HasRecursiveType(BoolTypeQuery): def __init__(self) -> None: super().__init__(ANY_STRATEGY) diff --git a/mypy/typevars.py b/mypy/typevars.py index 3d74a40c303f..a2198554c75a 100644 --- a/mypy/typevars.py +++ b/mypy/typevars.py @@ -3,13 +3,11 @@ from mypy.erasetype import erase_typevars from mypy.nodes import TypeInfo from mypy.types import ( - AnyType, Instance, ParamSpecType, ProperType, TupleType, Type, - TypeOfAny, TypeVarLikeType, TypeVarTupleType, TypeVarType, @@ -64,16 +62,7 @@ def fill_typevars(typ: TypeInfo) -> Instance | TupleType: def fill_typevars_with_any(typ: TypeInfo) -> Instance | TupleType: """Apply a correct number of Any's as type arguments to a type.""" - args: list[Type] = [] - for tv in typ.defn.type_vars: - # Valid erasure for *Ts is *tuple[Any, ...], not just Any. - if isinstance(tv, TypeVarTupleType): - args.append( - UnpackType(tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])) - ) - else: - args.append(AnyType(TypeOfAny.special_form)) - inst = Instance(typ, args) + inst = Instance(typ, [tv.default for tv in typ.defn.type_vars]) if typ.tuple_type is None: return inst erased_tuple_type = erase_typevars(typ.tuple_type, {tv.id for tv in typ.defn.type_vars}) diff --git a/test.py b/test.py new file mode 100644 index 000000000000..e25e990f4472 --- /dev/null +++ b/test.py @@ -0,0 +1,47 @@ +from typing import Callable, Generic, TypeVar # noqa: F401 +from typing_extensions import ParamSpec, TypeAlias, TypeVarTuple, Unpack, reveal_type # noqa: F401 + +T = TypeVar("T", default=int) +# T2 = TypeVar("T2", default=Callable[[T], int]) +T3 = TypeVar("T3", default=T) +# T4 = TypeVar("T4", default=T3 | T) + +# class Foop(Generic[T, T3, T4]): ... + +# reveal_type(Foop[str]()) # TODO should be Foop[str, T3=str, T4=str] not Foop[str, T3=str, T4=T=int | str] + + +# reveal_type(Foop()) +# A = TypeVar("A") +# B = TypeVar("B") +# C = TypeVar("C", default=dict[A, B]) +class Foo(Generic[T]): + class Bar(Generic[T3]): ... + + +reveal_type(Foo[bool]) +reveal_type(Foo[bool].Bar) +reveal_type(Foo[bool]()) +reveal_type(Foo[bool]().Bar) +# reveal_type(Foo().Bar()) + +# reveal_type(Foo[int]()) + + +# reveal_type(Foo) # revealed type is type[__main__.Foo[T`1 = builtins.int, T2`2 = def (T`1 = builtins.int) -> builtins.int]] +# reveal_type(Foo[str]) # revealed type is type[__main__.Foo[builtins.str, T2`2 = def (builtins.str) -> builtins.int]] +# reveal_type(Foo[str, int]) # revealed type is type[__main__.Foo[builtins.str, int]] + +# PreSpecialised: TypeAlias = Foo[str] +# # reveal_type(PreSpecialised) # revealed type is type[__main__.Foo[builtins.str, T2`2 = def (builtins.str) -> builtins.int]] +# reveal_type(PreSpecialised[int]) # borked + +# P = ParamSpec("P", default=(int, str)) +# P2 = ParamSpec("P2", default=P) + +# class Bar(Generic[P, P2]): ... + +# reveal_type(Bar[(int,)]) +# def foo(fn: Callable[P, int]) -> bool: ... +# reveal_type(foo) # revealed type is def [P = [builtins.int, builtins.str]] (fn: def (*P.args, **P.kwargs) -> builtins.int) -> builtins.bool +# reveal_type(Bar) # revealed type is type[__main__.Bar[P`3 = [builtins.int, builtins.str], P2`4 = P`3 = [builtins.int, builtins.str]]]