diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index f2eb886223..18a645b31f 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -2325,7 +2325,7 @@ class C: def copy(self, other: "C"): self.x = other.x -reveal_type(C().x) # revealed: Unknown | Literal[1] +reveal_type(C().x) # revealed: Literal[1] | Unknown ``` If the only assignment to a name is cyclic, we just infer `Unknown` for that attribute: @@ -2381,7 +2381,7 @@ class B: def copy(self, other: "A"): self.x = other.x -reveal_type(B().x) # revealed: Unknown | Literal[1] +reveal_type(B().x) # revealed: Literal[1] | Unknown reveal_type(A().x) # revealed: Unknown | Literal[1] class Base: @@ -2400,7 +2400,7 @@ class C2: def replace_with(self, other: "C2"): self.x = other.x.flip() -reveal_type(C2(Sub()).x) # revealed: Unknown | Base +reveal_type(C2(Sub()).x) # revealed: Base | Unknown class C3: def __init__(self, x: Sub): @@ -2409,8 +2409,8 @@ class C3: def replace_with(self, other: "C3"): self.x = [self.x[0].flip()] -# TODO: should be `Unknown | list[Unknown | Sub] | list[Unknown | Base]` -reveal_type(C3(Sub()).x) # revealed: Unknown | list[Unknown | Sub] | list[Divergent] +# TODO: should be `list[Unknown | Sub] | list[Unknown | Base] | Unknown` +reveal_type(C3(Sub()).x) # revealed: list[Unknown | Sub] | list[Divergent] | Unknown ``` And cycles between many attributes: @@ -2453,13 +2453,13 @@ class ManyCycles: self.x6 = self.x1 + self.x2 + self.x3 + self.x4 + self.x5 + self.x7 self.x7 = self.x1 + self.x2 + self.x3 + self.x4 + self.x5 + self.x6 - reveal_type(self.x1) # revealed: Unknown | int - reveal_type(self.x2) # revealed: Unknown | int - reveal_type(self.x3) # revealed: Unknown | int - reveal_type(self.x4) # revealed: Unknown | int - reveal_type(self.x5) # revealed: Unknown | int - reveal_type(self.x6) # revealed: Unknown | int - reveal_type(self.x7) 
# revealed: Unknown | int + reveal_type(self.x1) # revealed: int | Unknown + reveal_type(self.x2) # revealed: int | Unknown + reveal_type(self.x3) # revealed: int | Unknown + reveal_type(self.x4) # revealed: int | Unknown + reveal_type(self.x5) # revealed: int | Unknown + reveal_type(self.x6) # revealed: int | Unknown + reveal_type(self.x7) # revealed: int | Unknown class ManyCycles2: def __init__(self: "ManyCycles2"): @@ -2468,8 +2468,8 @@ class ManyCycles2: self.x3 = [1] def f1(self: "ManyCycles2"): - # TODO: should be Unknown | list[Unknown | int] | list[Divergent] - reveal_type(self.x3) # revealed: Unknown | list[Unknown | int] | list[Divergent] | list[Divergent] + # TODO: should be list[Unknown | int] | list[Divergent] | Unknown + reveal_type(self.x3) # revealed: list[Unknown | int] | list[Divergent] | list[Divergent] | Unknown self.x1 = [self.x2] + [self.x3] self.x2 = [self.x1] + [self.x3] @@ -2528,7 +2528,7 @@ class Counter: def increment(self: "Counter"): self.count = self.count + 1 -reveal_type(Counter().count) # revealed: Unknown | int +reveal_type(Counter().count) # revealed: int | Unknown ``` We also handle infinitely nested generics: @@ -2541,7 +2541,7 @@ class NestedLists: def f(self: "NestedLists"): self.x = [self.x] -reveal_type(NestedLists().x) # revealed: Unknown | Literal[1] | list[Divergent] +reveal_type(NestedLists().x) # revealed: Literal[1] | list[Divergent] | Unknown class NestedMixed: def f(self: "NestedMixed"): @@ -2550,7 +2550,7 @@ class NestedMixed: def g(self: "NestedMixed"): self.x = {self.x} -reveal_type(NestedMixed().x) # revealed: Unknown | list[Divergent] | set[Divergent] +reveal_type(NestedMixed().x) # revealed: list[Divergent] | set[Divergent] | Unknown ``` And cases where the types originate from annotations: @@ -2567,7 +2567,7 @@ class NestedLists2: def f(self: "NestedLists2"): self.x = make_list(self.x) -reveal_type(NestedLists2().x) # revealed: Unknown | list[Divergent] +reveal_type(NestedLists2().x) # revealed: 
list[Divergent] | Unknown ``` ### Builtin types attributes @@ -2673,8 +2673,8 @@ class C: def f(self, other: "C"): self.x = (other.x, 1) -reveal_type(C().x) # revealed: Unknown | tuple[Divergent, Literal[1]] -reveal_type(C().x[0]) # revealed: Unknown | Divergent +reveal_type(C().x) # revealed: tuple[Divergent, Literal[1]] | Unknown +reveal_type(C().x[0]) # revealed: Divergent | Unknown ``` This also works if the tuple is not constructed directly: @@ -2691,7 +2691,7 @@ class D: def f(self, other: "D"): self.x = make_tuple(other.x) -reveal_type(D().x) # revealed: Unknown | tuple[Divergent, Literal[1]] +reveal_type(D().x) # revealed: tuple[Divergent, Literal[1]] | Unknown ``` The tuple type may also expand exponentially "in breadth": @@ -2704,7 +2704,7 @@ class E: def f(self: "E"): self.x = duplicate(self.x) -reveal_type(E().x) # revealed: Unknown | tuple[Divergent, Divergent] +reveal_type(E().x) # revealed: tuple[Divergent, Divergent] | Unknown ``` And it also works for homogeneous tuples: @@ -2717,7 +2717,7 @@ class F: def f(self, other: "F"): self.x = make_homogeneous_tuple(other.x) -reveal_type(F().x) # revealed: Unknown | tuple[Divergent, ...] +reveal_type(F().x) # revealed: tuple[Divergent, ...] 
| Unknown ``` ## Attributes of standard library modules that aren't yet defined diff --git a/crates/ty_python_semantic/resources/mdtest/call/union.md b/crates/ty_python_semantic/resources/mdtest/call/union.md index 8d722288e4..4f374ac754 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/union.md +++ b/crates/ty_python_semantic/resources/mdtest/call/union.md @@ -227,22 +227,17 @@ def _(literals_2: Literal[0, 1], b: bool, flag: bool): literals_16 = 4 * literals_4 + literals_4 # Literal[0, 1, .., 15] literals_64 = 4 * literals_16 + literals_4 # Literal[0, 1, .., 63] literals_128 = 2 * literals_64 + literals_2 # Literal[0, 1, .., 127] - literals_256 = 2 * literals_128 + literals_2 # Literal[0, 1, .., 255] - # Going beyond the MAX_UNION_LITERALS limit (currently 512): - literals_512 = 2 * literals_256 + literals_2 # Literal[0, 1, .., 511] - reveal_type(literals_512 if flag else 512) # revealed: int + # Going beyond the MAX_UNION_LITERALS limit (currently 190): + literals_256 = 16 * literals_16 + literals_16 + reveal_type(literals_256) # revealed: int # Going beyond the limit when another type is already part of the union bool_and_literals_128 = b if flag else literals_128 # bool | Literal[0, 1, ..., 127] literals_128_shifted = literals_128 + 128 # Literal[128, 129, ..., 255] - literals_256_shifted = literals_256 + 256 # Literal[256, 257, ..., 511] # Now union the two: - two = bool_and_literals_128 if flag else literals_128_shifted - # revealed: bool | Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 
124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] - reveal_type(two) - reveal_type(two if flag else literals_256_shifted) # revealed: int + reveal_type(bool_and_literals_128 if flag else literals_128_shifted) # revealed: int ``` ## Simplifying gradually-equivalent types diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index 7d1686fb2d..6878c803c0 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -28,7 +28,7 @@ class Point: self.x, self.y = other.x, other.y p = Point() -reveal_type(p.x) # revealed: Unknown | int +reveal_type(p.x) # revealed: int | Unknown reveal_type(p.y) # revealed: Unknown | int ``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md index 57fc838498..a8c69f1f7f 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/aliases.md @@ -189,7 +189,7 @@ r5: RecursiveList[int] = [1, ["a"]] def _(x: RecursiveList[int]): if isinstance(x, list): # TODO: should be `list[RecursiveList[int]] - reveal_type(x[0]) # revealed: int | list[Any] + reveal_type(x[0]) # revealed: list[Any] | int if isinstance(x, list) and 
isinstance(x[0], list): # TODO: should be `list[RecursiveList[int]]` reveal_type(x[0]) # revealed: list[Any] diff --git a/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md b/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md index d4e4fafc73..77c1817ddd 100644 --- a/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md @@ -248,7 +248,7 @@ IntOrStr = TypeAliasType(get_name(), int | str) type OptNestedInt = int | tuple[OptNestedInt, ...] | None def f(x: OptNestedInt) -> None: - reveal_type(x) # revealed: int | tuple[OptNestedInt, ...] | None + reveal_type(x) # revealed: int | None | tuple[OptNestedInt, ...] if x is not None: reveal_type(x) # revealed: int | tuple[OptNestedInt, ...] ``` @@ -327,7 +327,7 @@ class B(A[Alias]): def f(b: B): reveal_type(b) # revealed: B - reveal_type(b.attr) # revealed: list[Alias] | int + reveal_type(b.attr) # revealed: int | list[Alias] ``` ### Mutually recursive @@ -344,12 +344,12 @@ def f(x: A): reveal_type(y) # revealed: tuple[A] def g(x: A | B): - reveal_type(x) # revealed: tuple[B] | None + reveal_type(x) # revealed: None | tuple[B] from ty_extensions import Intersection def h(x: Intersection[A, B]): - reveal_type(x) # revealed: tuple[B] | None + reveal_type(x) # revealed: None | tuple[B] ``` ### Self-recursive callable type @@ -450,5 +450,5 @@ type Y = X | str | dict[str, Y] def _(y: Y): if isinstance(y, dict): - reveal_type(y) # revealed: dict[str, X] | dict[str, Y] + reveal_type(y) # revealed: dict[str, Y] | dict[str, X] ``` diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index cfa4c68914..a5888995bc 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -3029,7 +3029,7 @@ class B(A[P[int]]): def f(b: B): reveal_type(b) # revealed: B reveal_type(b.attr) # 
revealed: P[int] - reveal_type(b.attr.attr) # revealed: P[int] | int + reveal_type(b.attr.attr) # revealed: int | P[int] ``` ### Recursive generic protocols with property members diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index da6279211d..a4ce2b46ba 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -13356,7 +13356,8 @@ impl<'db> IntersectionType<'db> { .map(|ty| ty.normalized_impl(db, visitor)) .collect(); - elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); + elements + .sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r, false)); elements } diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 64ca36010a..956343ce35 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -221,11 +221,12 @@ impl RecursivelyDefined { } } -/// If the value ​​is defined recursively, widening is performed from fewer literal elements, resulting in faster convergence of the fixed-point iteration. -const MAX_RECURSIVE_UNION_LITERALS: usize = 10; -/// If the value ​​is defined non-recursively, the fixed-point iteration will converge in one go, -/// so in principle we can have as many literal elements as we want, but to avoid unintended huge computational loads, we limit it to 256. -const MAX_NON_RECURSIVE_UNION_LITERALS: usize = 256; +// TODO increase this once we extend `UnionElement` throughout all union/intersection +// representations, so that we can make large unions of literals fast in all operations. +// +// For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number +// below 200, which is the salsa fixpoint iteration limit. 
+const MAX_UNION_LITERALS: usize = 190; pub(crate) struct UnionBuilder<'db> { elements: Vec>, @@ -283,27 +284,6 @@ impl<'db> UnionBuilder<'db> { self.elements.push(UnionElement::Type(Type::object())); } - fn widen_literal_types(&mut self, seen_aliases: &mut Vec>) { - let mut replace_with = vec![]; - for elem in &self.elements { - match elem { - UnionElement::IntLiterals(_) => { - replace_with.push(KnownClass::Int.to_instance(self.db)); - } - UnionElement::StringLiterals(_) => { - replace_with.push(KnownClass::Str.to_instance(self.db)); - } - UnionElement::BytesLiterals(_) => { - replace_with.push(KnownClass::Bytes.to_instance(self.db)); - } - UnionElement::Type(_) => {} - } - } - for ty in replace_with { - self.add_in_place_impl(ty, seen_aliases); - } - } - /// Adds a type to this union. pub(crate) fn add(mut self, ty: Type<'db>) -> Self { self.add_in_place(ty); @@ -316,15 +296,6 @@ impl<'db> UnionBuilder<'db> { } pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec>) { - let cycle_recovery = self.cycle_recovery; - let should_widen = |literals, recursively_defined: RecursivelyDefined| { - if recursively_defined.is_yes() && cycle_recovery { - literals >= MAX_RECURSIVE_UNION_LITERALS - } else { - literals >= MAX_NON_RECURSIVE_UNION_LITERALS - } - }; - match ty { Type::Union(union) => { let new_elements = union.elements(self.db); @@ -335,17 +306,6 @@ impl<'db> UnionBuilder<'db> { self.recursively_defined = self .recursively_defined .or(union.recursively_defined(self.db)); - if self.cycle_recovery && self.recursively_defined.is_yes() { - let literals = self.elements.iter().fold(0, |acc, elem| match elem { - UnionElement::IntLiterals(literals) => acc + literals.len(), - UnionElement::StringLiterals(literals) => acc + literals.len(), - UnionElement::BytesLiterals(literals) => acc + literals.len(), - UnionElement::Type(_) => acc, - }); - if should_widen(literals, self.recursively_defined) { - self.widen_literal_types(seen_aliases); - } - } } // 
Adding `Never` to a union is a no-op. Type::Never => {} @@ -369,7 +329,7 @@ impl<'db> UnionBuilder<'db> { for (index, element) in self.elements.iter_mut().enumerate() { match element { UnionElement::StringLiterals(literals) => { - if should_widen(literals.len(), self.recursively_defined) { + if literals.len() >= MAX_UNION_LITERALS { let replace_with = KnownClass::Str.to_instance(self.db); self.add_in_place_impl(replace_with, seen_aliases); return; @@ -414,7 +374,7 @@ impl<'db> UnionBuilder<'db> { for (index, element) in self.elements.iter_mut().enumerate() { match element { UnionElement::BytesLiterals(literals) => { - if should_widen(literals.len(), self.recursively_defined) { + if literals.len() >= MAX_UNION_LITERALS { let replace_with = KnownClass::Bytes.to_instance(self.db); self.add_in_place_impl(replace_with, seen_aliases); return; @@ -459,7 +419,7 @@ impl<'db> UnionBuilder<'db> { for (index, element) in self.elements.iter_mut().enumerate() { match element { UnionElement::IntLiterals(literals) => { - if should_widen(literals.len(), self.recursively_defined) { + if literals.len() >= MAX_UNION_LITERALS { let replace_with = KnownClass::Int.to_instance(self.db); self.add_in_place_impl(replace_with, seen_aliases); return; @@ -629,7 +589,13 @@ impl<'db> UnionBuilder<'db> { self.try_build().unwrap_or(Type::Never) } - pub(crate) fn try_build(self) -> Option> { + pub(crate) fn try_build(mut self) -> Option> { + // If the type is defined recursively, the union type is sorted and normalized. + // This is because the execution order of the queries is not deterministic and may result in a different order of elements. + // The order of the union type does not affect the type check result, but unstable output is undesirable. 
+ if self.cycle_recovery && self.recursively_defined.is_yes() { + self.order_elements = true; + } let mut types = vec![]; for element in self.elements { match element { @@ -646,7 +612,9 @@ impl<'db> UnionBuilder<'db> { } } if self.order_elements { - types.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(self.db, l, r)); + types.sort_unstable_by(|l, r| { + union_or_intersection_elements_ordering(self.db, l, r, self.cycle_recovery) + }); } match types.len() { 0 => None, diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index d2c9a71208..66c10e87d5 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -26,17 +26,22 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( db: &'db dyn Db, left: &Type<'db>, right: &Type<'db>, + cycle_recovery: bool, ) -> Ordering { - debug_assert_eq!( - *left, - left.normalized(db), - "`left` must be normalized before a meaningful ordering can be established" - ); - debug_assert_eq!( - *right, - right.normalized(db), - "`right` must be normalized before a meaningful ordering can be established" - ); + // If we sort union types in a cycle recovery function, this check is not necessary + // because the purpose is to stabilize the output and the sort order itself is not important. 
+ if !cycle_recovery { + debug_assert_eq!( + *left, + left.normalized(db), + "`left` must be normalized before a meaningful ordering can be established" + ); + debug_assert_eq!( + *right, + right.normalized(db), + "`right` must be normalized before a meaningful ordering can be established" + ); + } if left == right { return Ordering::Equal; @@ -128,7 +133,9 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::SubclassOf(_), _) => Ordering::Less, (_, Type::SubclassOf(_)) => Ordering::Greater, - (Type::TypeIs(left), Type::TypeIs(right)) => typeis_ordering(db, *left, *right), + (Type::TypeIs(left), Type::TypeIs(right)) => { + typeis_ordering(db, *left, *right, cycle_recovery) + } (Type::TypeIs(_), _) => Ordering::Less, (_, Type::TypeIs(_)) => Ordering::Greater, @@ -239,13 +246,15 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( return left_negative.len().cmp(&right_negative.len()); } for (left, right) in left_positive.iter().zip(right_positive) { - let ordering = union_or_intersection_elements_ordering(db, left, right); + let ordering = + union_or_intersection_elements_ordering(db, left, right, cycle_recovery); if ordering != Ordering::Equal { return ordering; } } for (left, right) in left_negative.iter().zip(right_negative) { - let ordering = union_or_intersection_elements_ordering(db, left, right); + let ordering = + union_or_intersection_elements_ordering(db, left, right, cycle_recovery); if ordering != Ordering::Equal { return ordering; } @@ -286,17 +295,26 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering /// * Boundness: Unbound precedes bound /// * Symbol name: String comparison /// * Guarded type: [`union_or_intersection_elements_ordering`] -fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering { +fn typeis_ordering( + db: &dyn Db, + left: TypeIsType, + right: TypeIsType, + cycle_recovery: bool, +) -> Ordering { let (left_ty, right_ty) = (left.return_type(db), 
right.return_type(db)); match (left.place_info(db), right.place_info(db)) { (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, - (None, None) => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), + (None, None) => { + union_or_intersection_elements_ordering(db, &left_ty, &right_ty, cycle_recovery) + } (Some(_), Some(_)) => match left.place_name(db).cmp(&right.place_name(db)) { - Ordering::Equal => union_or_intersection_elements_ordering(db, &left_ty, &right_ty), + Ordering::Equal => { + union_or_intersection_elements_ordering(db, &left_ty, &right_ty, cycle_recovery) + } ordering => ordering, }, }