diff --git a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md index fdf8fc67a6..c8c5341d57 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md +++ b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md @@ -310,7 +310,7 @@ x11: list[Literal[1] | Literal[2] | Literal[3]] = [1, 2, 3] reveal_type(x11) # revealed: list[Literal[1, 2, 3]] x12: Y[Y[Literal[1]]] = [[1]] -reveal_type(x12) # revealed: list[list[Literal[1]]] +reveal_type(x12) # revealed: list[Y[Literal[1]]] x13: list[tuple[Literal[1], Literal[2], Literal[3]]] = [(1, 2, 3)] reveal_type(x13) # revealed: list[tuple[Literal[1], Literal[2], Literal[3]]] diff --git a/crates/ty_python_semantic/resources/mdtest/type_display/callable.md b/crates/ty_python_semantic/resources/mdtest/type_display/callable.md index 035d822101..d90911213d 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_display/callable.md +++ b/crates/ty_python_semantic/resources/mdtest/type_display/callable.md @@ -67,3 +67,53 @@ def _(x: object): c = C(x) reveal_type(c) # revealed: C[Top[(...)]] ``` + +## Type aliases are not expanded unless necessary + +```toml +[environment] +python-version = "3.12" +``` + +```py +type Scalar = int | float +type Array1d = list[Scalar] | tuple[Scalar] + +def f(x: Scalar | Array1d) -> None: + pass + +reveal_type(f) # revealed: def f(x: Scalar | Array1d) -> None + +class Foo: + def f(self, x: Scalar | Array1d) -> None: + pass + +reveal_type(Foo().f) # revealed: bound method Foo.f(x: Scalar | Array1d) -> None + +type ArrayNd = Scalar | list[ArrayNd] | tuple[ArrayNd] + +def g(x: Scalar | ArrayNd) -> None: + pass + +reveal_type(g) # revealed: def g(x: Scalar | ArrayNd) -> None + +class Bar: + def g(self, x: Scalar | ArrayNd) -> None: + pass + +# TODO: should be `bound method Bar.g(x: Scalar | ArrayNd) -> None` +reveal_type(Bar().g) # revealed: bound method Bar.g(x: Scalar | list[Any] | tuple[Any]) -> None + +type GenericArray1d[T] = list[T] | tuple[T] + +def h(x: Scalar | GenericArray1d[Scalar]) -> None: + pass + +reveal_type(h) # revealed: def h(x: Scalar | GenericArray1d[Scalar]) -> None + +class Baz: + def h(self, x: Scalar | GenericArray1d[Scalar]) -> None: + pass + +reveal_type(Baz().h) # revealed: bound method Baz.h(x: Scalar | GenericArray1d[Scalar]) -> None +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index d98b3f4185..13d62425b5 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -8070,7 +8070,7 @@ impl<'db> Type<'db> { Type::PropertyInstance(property.apply_type_mapping_impl(db, type_mapping, tcx, visitor)) } - Type::Union(union) => union.map(db, |element| { + Type::Union(union) => union.map_leave_aliases(db, |element| { element.apply_type_mapping_impl(db, type_mapping, tcx, visitor) }), Type::Intersection(intersection) => { @@ -8099,7 +8099,16 @@ impl<'db> Type<'db> { // In the case of recursive type aliases, this leads to infinite recursion. // Instead, call `raw_value_type` and perform the specialization after the `visitor` cache has been created. let value_type = visitor.visit(self, || alias.raw_value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor)); - alias.apply_function_specialization(db, value_type).apply_type_mapping_impl(db, type_mapping, tcx, visitor) + let mapped = alias.apply_function_specialization(db, value_type).apply_type_mapping_impl(db, type_mapping, tcx, visitor); + let is_recursive = any_over_type(db, alias.raw_value_type(db).expand_eagerly(db), &|ty| ty.is_divergent(), false); + // If the type mapping does not result in any change to this (non-recursive) type alias, do not expand it. + // TODO: The rule that recursive type aliases must be expanded could potentially be removed, + // but doing so would currently cause a stack overflow, as the current recursive type alias specialization/expansion mechanism is incomplete. + if !is_recursive && alias.value_type(db) == mapped { + self + } else { + mapped + } } Type::ModuleLiteral(_) @@ -14012,6 +14021,23 @@ impl<'db> UnionType<'db> { .build() } + /// A version of [`UnionType::map`] that does not unpack type aliases. + pub(crate) fn map_leave_aliases( + self, + db: &'db dyn Db, + transform_fn: impl FnMut(&Type<'db>) -> Type<'db>, + ) -> Type<'db> { + self.elements(db) + .iter() + .map(transform_fn) + .fold( + UnionBuilder::new(db).unpack_aliases(false), + UnionBuilder::add, + ) + .recursively_defined(self.recursively_defined(db)) + .build() + } + /// A fallible version of [`UnionType::map`]. /// /// For each element in `self`, `transform_fn` is called on that element.