diff --git a/crates/ty_python_semantic/resources/mdtest/call/union.md b/crates/ty_python_semantic/resources/mdtest/call/union.md index 4f374ac754..8d722288e4 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/union.md +++ b/crates/ty_python_semantic/resources/mdtest/call/union.md @@ -227,17 +227,22 @@ def _(literals_2: Literal[0, 1], b: bool, flag: bool): literals_16 = 4 * literals_4 + literals_4 # Literal[0, 1, .., 15] literals_64 = 4 * literals_16 + literals_4 # Literal[0, 1, .., 63] literals_128 = 2 * literals_64 + literals_2 # Literal[0, 1, .., 127] + literals_256 = 2 * literals_128 + literals_2 # Literal[0, 1, .., 255] - # Going beyond the MAX_UNION_LITERALS limit (currently 200): - literals_256 = 16 * literals_16 + literals_16 - reveal_type(literals_256) # revealed: int + # Going beyond the MAX_NON_RECURSIVE_UNION_LITERALS limit (currently 256): + literals_512 = 2 * literals_256 + literals_2 # Literal[0, 1, .., 511] + reveal_type(literals_512 if flag else 512) # revealed: int # Going beyond the limit when another type is already part of the union bool_and_literals_128 = b if flag else literals_128 # bool | Literal[0, 1, ..., 127] literals_128_shifted = literals_128 + 128 # Literal[128, 129, ..., 255] + literals_256_shifted = literals_256 + 256 # Literal[256, 257, ..., 511] # Now union the two: - reveal_type(bool_and_literals_128 if flag else literals_128_shifted) # revealed: int + two = bool_and_literals_128 if flag else literals_128_shifted + # revealed: bool | Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 
122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + reveal_type(two) + reveal_type(two if flag else literals_256_shifted) # revealed: int ``` ## Simplifying gradually-equivalent types diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 820d31c1b0..e297ebf821 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -44,6 +44,7 @@ use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{imported_modules, place_table, semantic_index}; use crate::suppression::check_suppressions; use crate::types::bound_super::BoundSuperType; +use crate::types::builder::RecursivelyDefined; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; use crate::types::constraints::{ @@ -9668,6 +9669,7 @@ impl<'db> TypeVarInstance<'db> { .skip(1) .map(|arg| definition_expression_type(db, definition, arg)) .collect::>(), + RecursivelyDefined::No, ) } _ => return None, @@ -10120,6 +10122,7 @@ impl<'db> TypeVarBoundOrConstraints<'db> { .iter() .map(|ty| ty.normalized_impl(db, visitor)) .collect::>(), + constraints.recursively_defined(db), )) } } @@ -10201,6 +10204,7 @@ impl<'db> TypeVarBoundOrConstraints<'db> { .iter() .map(|ty| ty.materialize(db, materialization_kind, visitor)) .collect::>(), + 
RecursivelyDefined::No, )) } } @@ -13142,6 +13146,9 @@ pub struct UnionType<'db> { /// The union type includes values in any of these types. #[returns(deref)] pub elements: Box<[Type<'db>]>, + /// Whether the value pointed to by this type is recursively defined. + /// If `Yes`, union literal widening is performed early. + recursively_defined: RecursivelyDefined, } pub(crate) fn walk_union<'db, V: visitor::TypeVisitor<'db> + ?Sized>( @@ -13226,7 +13233,14 @@ impl<'db> UnionType<'db> { db: &'db dyn Db, transform_fn: impl FnMut(&Type<'db>) -> Type<'db>, ) -> Type<'db> { - Self::from_elements(db, self.elements(db).iter().map(transform_fn)) + self.elements(db) + .iter() + .map(transform_fn) + .fold(UnionBuilder::new(db), |builder, element| { + builder.add(element) + }) + .recursively_defined(self.recursively_defined(db)) + .build() } /// A fallible version of [`UnionType::map`]. @@ -13241,7 +13255,12 @@ impl<'db> UnionType<'db> { db: &'db dyn Db, transform_fn: impl FnMut(&Type<'db>) -> Option>, ) -> Option> { - Self::try_from_elements(db, self.elements(db).iter().map(transform_fn)) + let mut builder = UnionBuilder::new(db); + for element in self.elements(db).iter().map(transform_fn) { + builder = builder.add(element?); + } + builder = builder.recursively_defined(self.recursively_defined(db)); + Some(builder.build()) } pub(crate) fn to_instance(self, db: &'db dyn Db) -> Option> { @@ -13253,7 +13272,14 @@ impl<'db> UnionType<'db> { db: &'db dyn Db, mut f: impl FnMut(&Type<'db>) -> bool, ) -> Type<'db> { - Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty))) + self.elements(db) + .iter() + .filter(|ty| f(ty)) + .fold(UnionBuilder::new(db), |builder, element| { + builder.add(*element) + }) + .recursively_defined(self.recursively_defined(db)) + .build() } pub(crate) fn map_with_boundness( @@ -13288,7 +13314,9 @@ impl<'db> UnionType<'db> { Place::Undefined } else { Place::Defined( - builder.build(), + builder + 
.recursively_defined(self.recursively_defined(db)) + .build(), origin, if possibly_unbound { Definedness::PossiblyUndefined @@ -13336,7 +13364,9 @@ impl<'db> UnionType<'db> { Place::Undefined } else { Place::Defined( - builder.build(), + builder + .recursively_defined(self.recursively_defined(db)) + .build(), origin, if possibly_unbound { Definedness::PossiblyUndefined @@ -13371,6 +13401,7 @@ impl<'db> UnionType<'db> { .unpack_aliases(true), UnionBuilder::add, ) + .recursively_defined(self.recursively_defined(db)) .build() } @@ -13383,7 +13414,8 @@ impl<'db> UnionType<'db> { let mut builder = UnionBuilder::new(db) .order_elements(false) .unpack_aliases(false) - .cycle_recovery(true); + .cycle_recovery(true) + .recursively_defined(self.recursively_defined(db)); let mut empty = true; for ty in self.elements(db) { if nested { @@ -13398,6 +13430,7 @@ impl<'db> UnionType<'db> { // `Divergent` in a union type does not mean true divergence, so we skip it if not nested. // e.g. T | Divergent == T | (T | (T | (T | ...))) == T if ty == &div { + builder = builder.recursively_defined(RecursivelyDefined::Yes); continue; } builder = builder.add( diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 0618682837..64ca36010a 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -202,12 +202,30 @@ enum ReduceResult<'db> { Type(Type<'db>), } -// TODO increase this once we extend `UnionElement` throughout all union/intersection -// representations, so that we can make large unions of literals fast in all operations. -// -// For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number -// below 200, which is the salsa fixpoint iteration limit. 
-const MAX_UNION_LITERALS: usize = 190; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, get_size2::GetSize)] +pub enum RecursivelyDefined { + Yes, + No, +} + +impl RecursivelyDefined { + const fn is_yes(self) -> bool { + matches!(self, RecursivelyDefined::Yes) + } + + const fn or(self, other: RecursivelyDefined) -> RecursivelyDefined { + match (self, other) { + (RecursivelyDefined::Yes, _) | (_, RecursivelyDefined::Yes) => RecursivelyDefined::Yes, + _ => RecursivelyDefined::No, + } + } +} + +/// If the value is defined recursively, widening is performed from fewer literal elements, resulting in faster convergence of the fixed-point iteration. +const MAX_RECURSIVE_UNION_LITERALS: usize = 10; +/// If the value is defined non-recursively, the fixed-point iteration will converge in one go, +/// so in principle we can have as many literal elements as we want, but to avoid unintended huge computational loads, we limit it to 256. +const MAX_NON_RECURSIVE_UNION_LITERALS: usize = 256; pub(crate) struct UnionBuilder<'db> { elements: Vec>, @@ -217,6 +235,7 @@ pub(crate) struct UnionBuilder<'db> { // This is enabled when joining types in a `cycle_recovery` function. // Since a cycle cannot be created within a `cycle_recovery` function, execution of `is_redundant_with` is skipped. 
cycle_recovery: bool, + recursively_defined: RecursivelyDefined, } impl<'db> UnionBuilder<'db> { @@ -227,6 +246,7 @@ impl<'db> UnionBuilder<'db> { unpack_aliases: true, order_elements: false, cycle_recovery: false, + recursively_defined: RecursivelyDefined::No, } } @@ -248,6 +268,11 @@ impl<'db> UnionBuilder<'db> { self } + pub(crate) fn recursively_defined(mut self, val: RecursivelyDefined) -> Self { + self.recursively_defined = val; + self + } + pub(crate) fn is_empty(&self) -> bool { self.elements.is_empty() } @@ -258,6 +283,27 @@ impl<'db> UnionBuilder<'db> { self.elements.push(UnionElement::Type(Type::object())); } + fn widen_literal_types(&mut self, seen_aliases: &mut Vec>) { + let mut replace_with = vec![]; + for elem in &self.elements { + match elem { + UnionElement::IntLiterals(_) => { + replace_with.push(KnownClass::Int.to_instance(self.db)); + } + UnionElement::StringLiterals(_) => { + replace_with.push(KnownClass::Str.to_instance(self.db)); + } + UnionElement::BytesLiterals(_) => { + replace_with.push(KnownClass::Bytes.to_instance(self.db)); + } + UnionElement::Type(_) => {} + } + } + for ty in replace_with { + self.add_in_place_impl(ty, seen_aliases); + } + } + /// Adds a type to this union. 
pub(crate) fn add(mut self, ty: Type<'db>) -> Self { self.add_in_place(ty); @@ -270,6 +316,15 @@ impl<'db> UnionBuilder<'db> { } pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec>) { + let cycle_recovery = self.cycle_recovery; + let should_widen = |literals, recursively_defined: RecursivelyDefined| { + if recursively_defined.is_yes() && cycle_recovery { + literals >= MAX_RECURSIVE_UNION_LITERALS + } else { + literals >= MAX_NON_RECURSIVE_UNION_LITERALS + } + }; + match ty { Type::Union(union) => { let new_elements = union.elements(self.db); @@ -277,6 +332,20 @@ impl<'db> UnionBuilder<'db> { for element in new_elements { self.add_in_place_impl(*element, seen_aliases); } + self.recursively_defined = self + .recursively_defined + .or(union.recursively_defined(self.db)); + if self.cycle_recovery && self.recursively_defined.is_yes() { + let literals = self.elements.iter().fold(0, |acc, elem| match elem { + UnionElement::IntLiterals(literals) => acc + literals.len(), + UnionElement::StringLiterals(literals) => acc + literals.len(), + UnionElement::BytesLiterals(literals) => acc + literals.len(), + UnionElement::Type(_) => acc, + }); + if should_widen(literals, self.recursively_defined) { + self.widen_literal_types(seen_aliases); + } + } } // Adding `Never` to a union is a no-op. 
Type::Never => {} @@ -300,7 +369,7 @@ impl<'db> UnionBuilder<'db> { for (index, element) in self.elements.iter_mut().enumerate() { match element { UnionElement::StringLiterals(literals) => { - if literals.len() >= MAX_UNION_LITERALS { + if should_widen(literals.len(), self.recursively_defined) { let replace_with = KnownClass::Str.to_instance(self.db); self.add_in_place_impl(replace_with, seen_aliases); return; @@ -345,7 +414,7 @@ impl<'db> UnionBuilder<'db> { for (index, element) in self.elements.iter_mut().enumerate() { match element { UnionElement::BytesLiterals(literals) => { - if literals.len() >= MAX_UNION_LITERALS { + if should_widen(literals.len(), self.recursively_defined) { let replace_with = KnownClass::Bytes.to_instance(self.db); self.add_in_place_impl(replace_with, seen_aliases); return; @@ -390,7 +459,7 @@ impl<'db> UnionBuilder<'db> { for (index, element) in self.elements.iter_mut().enumerate() { match element { UnionElement::IntLiterals(literals) => { - if literals.len() >= MAX_UNION_LITERALS { + if should_widen(literals.len(), self.recursively_defined) { let replace_with = KnownClass::Int.to_instance(self.db); self.add_in_place_impl(replace_with, seen_aliases); return; @@ -585,6 +654,7 @@ impl<'db> UnionBuilder<'db> { _ => Some(Type::Union(UnionType::new( self.db, types.into_boxed_slice(), + self.recursively_defined, ))), } } @@ -696,6 +766,7 @@ impl<'db> IntersectionBuilder<'db> { enum_member_literals(db, instance.class_literal(db), None) .expect("Calling `enum_member_literals` on an enum class") .collect::>(), + RecursivelyDefined::No, )), seen_aliases, ) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index dd04b6d001..ccb25b66a2 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -50,6 +50,7 @@ use crate::semantic_index::{ ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table, 
}; use crate::subscript::{PyIndex, PySlice}; +use crate::types::builder::RecursivelyDefined; use crate::types::call::bind::{CallableDescription, MatchingOverloadIndex}; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; @@ -3283,6 +3284,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { elts.iter() .map(|expr| self.infer_type_expression(expr)) .collect::>(), + RecursivelyDefined::No, )); self.store_expression_type(expr, ty); } diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index d906a472c3..4d15700163 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -416,7 +416,7 @@ impl<'db> SubclassOfInner<'db> { ) } Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { - let constraints = constraints + let constraints_types = constraints .elements(db) .iter() .map(|constraint| { @@ -425,7 +425,11 @@ impl<'db> SubclassOfInner<'db> { }) .collect::>(); - TypeVarBoundOrConstraints::Constraints(UnionType::new(db, constraints)) + TypeVarBoundOrConstraints::Constraints(UnionType::new( + db, + constraints_types, + constraints.recursively_defined(db), + )) } }) }); diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index d2b96f2849..6405ae3ae7 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -23,6 +23,7 @@ use itertools::{Either, EitherOrBoth, Itertools}; use crate::semantic_index::definition::Definition; use crate::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; +use crate::types::builder::RecursivelyDefined; use crate::types::class::{ClassType, KnownClass}; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use 
crate::types::generics::InferableTypeVars; @@ -1458,7 +1459,7 @@ impl<'db> Tuple> { // those techniques ensure that union elements are deduplicated and unions are eagerly simplified // into other types where necessary. Here, however, we know that there are no duplicates // in this union, so it's probably more efficient to use `UnionType::new()` directly. - Type::Union(UnionType::new(db, elements)) + Type::Union(UnionType::new(db, elements, RecursivelyDefined::No)) }; TupleSpec::heterogeneous([