diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md index 4ea9e7adf8..3191cf5683 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md @@ -170,3 +170,64 @@ type X[T: X] = T def _(x: X): assert x ``` + +## Recursive generic type aliases + +```py +type RecursiveList[T] = T | list[RecursiveList[T]] + +r1: RecursiveList[int] = 1 +r2: RecursiveList[int] = [1, [1, 2, 3]] +# error: [invalid-assignment] "Object of type `Literal["a"]` is not assignable to `RecursiveList[int]`" +r3: RecursiveList[int] = "a" +# error: [invalid-assignment] +r4: RecursiveList[int] = ["a"] +# TODO: this should be an error +r5: RecursiveList[int] = [1, ["a"]] + +def _(x: RecursiveList[int]): + if isinstance(x, list): + # TODO: should be `list[RecursiveList[int]] + reveal_type(x[0]) # revealed: int | list[Any] + if isinstance(x, list) and isinstance(x[0], list): + # TODO: should be `list[RecursiveList[int]]` + reveal_type(x[0]) # revealed: list[Any] +``` + +Assignment checks respect structural subtyping, i.e. type aliases with the same structure are +assignable to each other. + +```py +# This is structurally equivalent to RecursiveList[T]. +type RecursiveList2[T] = T | list[T | list[RecursiveList[T]]] +# This is not structurally equivalent to RecursiveList[T]. +type RecursiveList3[T] = T | list[list[RecursiveList[T]]] + +def _(x: RecursiveList[int], y: RecursiveList2[int]): + r1: RecursiveList2[int] = x + # error: [invalid-assignment] + r2: RecursiveList3[int] = x + + r3: RecursiveList[int] = y + # error: [invalid-assignment] + r4: RecursiveList3[int] = y +``` + +It is also possible to handle divergent type aliases that are not actually have instances. + +```py +# The type variable `T` has no meaning here, it's just to make sure it works correctly. +type DivergentList[T] = list[DivergentList[T]] + +d1: DivergentList[int] = [] +# error: [invalid-assignment] +d2: DivergentList[int] = [1] +# error: [invalid-assignment] +d3: DivergentList[int] = ["a"] +# TODO: this should be an error +d4: DivergentList[int] = [[1]] + +def _(x: DivergentList[int]): + d1: DivergentList[int] = [x] + d2: DivergentList[int] = x[0] +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 31cb931396..6cc7f20739 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -6772,7 +6772,11 @@ impl<'db> Type<'db> { Type::TypeIs(type_is) => type_is.with_type(db, type_is.return_type(db).apply_type_mapping(db, type_mapping, tcx)), Type::TypeAlias(alias) => { - visitor.visit(self, || alias.value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor)) + // Do not call `value_type` here. `value_type` does the specialization internally, so `apply_type_mapping` is performed without `visitor` inheritance. + // In the case of recursive type aliases, this leads to infinite recursion. + // Instead, call `raw_value_type` and perform the specialization after the `visitor` cache has been created. + let value_type = visitor.visit(self, || alias.raw_value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor)); + alias.apply_function_specialization(db, value_type).apply_type_mapping_impl(db, type_mapping, tcx, visitor) } Type::ModuleLiteral(_) @@ -10716,31 +10720,12 @@ impl<'db> PEP695TypeAliasType<'db> { } /// The RHS type of a PEP-695 style type alias with specialization applied. - #[salsa::tracked(cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> { - let value_type = self.raw_value_type(db); - - if let Some(generic_context) = self.generic_context(db) { - let specialization = self - .specialization(db) - .unwrap_or_else(|| generic_context.default_specialization(db, None)); - - value_type.apply_specialization(db, specialization) - } else { - value_type - } + self.apply_function_specialization(db, self.raw_value_type(db)) } /// The RHS type of a PEP-695 style type alias with *no* specialization applied. - /// - /// ## Warning - /// - /// This uses the semantic index to find the definition of the type alias. This means that if the - /// calling query is not in the same file as this type alias is defined in, then this will create - /// a cross-module dependency directly on the full AST which will lead to cache - /// over-invalidation. - /// This method also calls the type inference functions, and since type aliases can have recursive structures, - /// we should be careful not to create infinite recursions in this method (or make it tracked if necessary). + #[salsa::tracked(cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.rhs_scope(db); let module = parsed_module(db, scope.file(db)).load(db); @@ -10750,6 +10735,17 @@ impl<'db> PEP695TypeAliasType<'db> { definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value) } + fn apply_function_specialization(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + if let Some(generic_context) = self.generic_context(db) { + let specialization = self + .specialization(db) + .unwrap_or_else(|| generic_context.default_specialization(db, None)); + ty.apply_specialization(db, specialization) + } else { + ty + } + } + pub(crate) fn apply_specialization( self, db: &'db dyn Db, @@ -10939,6 +10935,13 @@ impl<'db> TypeAliasType<'db> { } } + fn apply_function_specialization(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + match self { + TypeAliasType::PEP695(type_alias) => type_alias.apply_function_specialization(db, ty), + TypeAliasType::ManualPEP695(_) => ty, + } + } + pub(crate) fn apply_specialization( self, db: &'db dyn Db, @@ -11799,6 +11802,9 @@ type CovariantAlias[T] = Covariant[T] type ContravariantAlias[T] = Contravariant[T] type InvariantAlias[T] = Invariant[T] type BivariantAlias[T] = Bivariant[T] + +type RecursiveAlias[T] = None | list[RecursiveAlias[T]] +type RecursiveAlias2[T] = None | list[T] | list[RecursiveAlias2[T]] "#, ) .unwrap(); @@ -11829,5 +11835,19 @@ type BivariantAlias[T] = Bivariant[T] .variance_of(&db, get_bound_typevar(&db, bivariant)), TypeVarVariance::Bivariant ); + + let recursive = get_type_alias(&db, "RecursiveAlias"); + assert_eq!( + KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(recursive)) + .variance_of(&db, get_bound_typevar(&db, recursive)), + TypeVarVariance::Bivariant + ); + + let recursive2 = get_type_alias(&db, "RecursiveAlias2"); + assert_eq!( + KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(recursive2)) + .variance_of(&db, get_bound_typevar(&db, recursive2)), + TypeVarVariance::Invariant + ); } }