diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 0940ec1f6d..4d117ed4f7 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1266,7 +1266,7 @@ impl<'db> ClassLiteral<'db> { class_def_node.type_params.as_ref().map(|type_params| { let index = semantic_index(db, scope.file(db)); let definition = index.expect_single_definition(class_def_node); - GenericContext::from_type_params(db, index, definition, type_params) + GenericContext::from_type_params(db, index, definition, type_params, self.known(db)) }) } @@ -1290,6 +1290,7 @@ impl<'db> ClassLiteral<'db> { .iter() .copied() .filter(|ty| matches!(ty, Type::GenericAlias(_))), + self.known(db), ) } @@ -1334,8 +1335,7 @@ impl<'db> ClassLiteral<'db> { specialization: Option>, ) -> ClassType<'db> { self.apply_specialization(db, |generic_context| { - specialization - .unwrap_or_else(|| generic_context.default_specialization(db, self.known(db))) + specialization.unwrap_or_else(|| generic_context.default_specialization(db)) }) } @@ -1344,7 +1344,7 @@ impl<'db> ClassLiteral<'db> { /// applies the default specialization to the class's typevars. pub(crate) fn default_specialization(self, db: &'db dyn Db) -> ClassType<'db> { self.apply_specialization(db, |generic_context| { - generic_context.default_specialization(db, self.known(db)) + generic_context.default_specialization(db) }) } diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index b50ca3d996..d6e5383ecb 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -454,7 +454,6 @@ impl Display for DisplayGenericContext<'_> { let variables = self.generic_context.variables(self.db); let non_implicit_variables: Vec<_> = variables - .iter() .filter(|bound_typevar| !bound_typevar.typevar(self.db).is_implicit(self.db)) .collect(); diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 8a38ea0a84..4f811ae26d 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -340,7 +340,7 @@ impl<'db> OverloadLiteral<'db> { let definition = self.definition(db); let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| { let index = semantic_index(db, scope.file(db)); - GenericContext::from_type_params(db, index, definition, type_params) + GenericContext::from_type_params(db, index, definition, type_params, None) }); let index = semantic_index(db, scope.file(db)); diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 473dc9e1f1..51ee8d5a95 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; +use itertools::Either; use ruff_db::parsed::ParsedModuleRef; use ruff_python_ast as ast; use rustc_hash::FxHashMap; @@ -12,7 +13,7 @@ use crate::types::class_base::ClassBase; use crate::types::infer::infer_definition_types; use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; -use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; +use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, HasRelationToVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, Type, TypeMapping, TypeRelation, TypeTransformer, @@ -82,6 +83,14 @@ pub(crate) fn bind_typevar<'db>( }) } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum GenericContextInner<'db> { + Tuple { + single_typevar: BoundTypeVarInstance<'db>, + }, + NonTuple(FxOrderSet>), +} + /// A list of formal type variables for a generic function, class, or type alias. /// /// TODO: Handle nested generic contexts better, with actual parent links to the lexically @@ -94,7 +103,7 @@ pub(crate) fn bind_typevar<'db>( #[derive(PartialOrd, Ord)] pub struct GenericContext<'db> { #[returns(ref)] - pub(crate) variables: FxOrderSet>, + pub(crate) inner: GenericContextInner<'db>, } pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( @@ -103,7 +112,7 @@ pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?S visitor: &V, ) { for bound_typevar in context.variables(db) { - visitor.visit_bound_type_var_type(db, *bound_typevar); + visitor.visit_bound_type_var_type(db, bound_typevar); } } @@ -111,12 +120,25 @@ pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?S impl get_size2::GetSize for GenericContext<'_> {} impl<'db> GenericContext<'db> { + pub(crate) fn variables( + self, + db: &'db dyn Db, + ) -> impl ExactSizeIterator> { + match self.inner(db) { + GenericContextInner::Tuple { single_typevar } => { + Either::Left(std::iter::once(*single_typevar)) + } + GenericContextInner::NonTuple(variables) => Either::Right(variables.iter().copied()), + } + } + /// Creates a generic context from a list of PEP-695 type parameters. pub(crate) fn from_type_params( db: &'db dyn Db, index: &'db SemanticIndex<'db>, binding_context: Definition<'db>, type_params_node: &ast::TypeParams, + known_class: Option, ) -> Self { let variables: FxOrderSet<_> = type_params_node .iter() @@ -124,7 +146,23 @@ impl<'db> GenericContext<'db> { Self::variable_from_type_param(db, index, binding_context, type_param) }) .collect(); - Self::new(db, variables) + + match known_class { + Some(KnownClass::Tuple) => { + assert_eq!( + variables.len(), + 1, + "Tuple should always have exactly one typevar" + ); + Self::new( + db, + GenericContextInner::Tuple { + single_typevar: variables[0], + }, + ) + } + _ => Self::new(db, GenericContextInner::NonTuple(variables)), + } } fn variable_from_type_param( @@ -174,7 +212,7 @@ impl<'db> GenericContext<'db> { if variables.is_empty() { return None; } - Some(Self::new(db, variables)) + Some(Self::new(db, GenericContextInner::NonTuple(variables))) } /// Creates a generic context from the legacy `TypeVar`s that appear in class's base class @@ -182,6 +220,7 @@ impl<'db> GenericContext<'db> { pub(crate) fn from_base_classes( db: &'db dyn Db, bases: impl Iterator>, + known_class: Option, ) -> Option { let mut variables = FxOrderSet::default(); for base in bases { @@ -190,7 +229,23 @@ impl<'db> GenericContext<'db> { if variables.is_empty() { return None; } - Some(Self::new(db, variables)) + let context = match known_class { + Some(KnownClass::Tuple) => { + assert_eq!( + variables.len(), + 1, + "Tuple should always have exactly one typevar" + ); + Self::new( + db, + GenericContextInner::Tuple { + single_typevar: variables[0], + }, + ) + } + _ => Self::new(db, GenericContextInner::NonTuple(variables)), + }; + Some(context) } pub(crate) fn len(self, db: &'db dyn Db) -> usize { @@ -200,8 +255,7 @@ impl<'db> GenericContext<'db> { pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { let parameters = Parameters::new( self.variables(db) - .iter() - .map(|typevar| Self::parameter_from_typevar(db, *typevar)), + .map(|typevar| Self::parameter_from_typevar(db, typevar)), ); Signature::new(parameters, None) } @@ -231,50 +285,54 @@ impl<'db> GenericContext<'db> { parameter } - pub(crate) fn default_specialization( - self, - db: &'db dyn Db, - known_class: Option, - ) -> Specialization<'db> { - let partial = self.specialize_partial(db, &vec![None; self.variables(db).len()]); - if known_class == Some(KnownClass::Tuple) { - Specialization::new( - db, - self, - partial.types(db), - Some(TupleType::homogeneous(db, Type::unknown())), - ) - } else { - partial - } + pub(crate) fn default_specialization(self, db: &'db dyn Db) -> Specialization<'db> { + self.specialize_partial(db, &vec![None; self.variables(db).len()]) } pub(crate) fn identity_specialization(self, db: &'db dyn Db) -> Specialization<'db> { - let types = self - .variables(db) - .iter() - .map(|typevar| Type::TypeVar(*typevar)) - .collect(); + let types = self.variables(db).map(Type::TypeVar).collect(); self.specialize(db, types) } pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> { let types = vec![Type::unknown(); self.variables(db).len()]; - self.specialize(db, types.into()) + self.specialize(db, types.into_boxed_slice()) } /// Returns a tuple type of the typevars introduced by this generic context. pub(crate) fn as_tuple(self, db: &'db dyn Db) -> Type<'db> { - Type::heterogeneous_tuple( - db, - self.variables(db) - .iter() - .map(|typevar| Type::TypeVar(*typevar)), - ) + Type::heterogeneous_tuple(db, self.variables(db).map(Type::TypeVar)) } pub(crate) fn is_subset_of(self, db: &'db dyn Db, other: GenericContext<'db>) -> bool { - self.variables(db).is_subset(other.variables(db)) + match (self.inner(db), other.inner(db)) { + (GenericContextInner::NonTuple(left), GenericContextInner::NonTuple(right)) => { + left.is_subset(right) + } + + ( + GenericContextInner::Tuple { + single_typevar: left, + }, + GenericContextInner::NonTuple(right), + ) => right.contains(left), + + ( + GenericContextInner::NonTuple(left), + GenericContextInner::Tuple { + single_typevar: right, + }, + ) => left.len() == 1 && left[0] == *right, + + ( + GenericContextInner::Tuple { + single_typevar: left, + }, + GenericContextInner::Tuple { + single_typevar: right, + }, + ) => left == right, + } } pub(crate) fn binds_typevar( @@ -283,9 +341,7 @@ impl<'db> GenericContext<'db> { typevar: TypeVarInstance<'db>, ) -> Option> { self.variables(db) - .iter() .find(|self_bound_typevar| self_bound_typevar.typevar(db) == typevar) - .copied() } /// Creates a specialization of this generic context. Panics if the length of `types` does not @@ -297,18 +353,25 @@ impl<'db> GenericContext<'db> { db: &'db dyn Db, types: Box<[Type<'db>]>, ) -> Specialization<'db> { - assert!(self.variables(db).len() == types.len()); - Specialization::new(db, self, types, None) + assert_eq!(self.variables(db).len(), types.len()); + debug_assert!( + matches!(self.inner(db), GenericContextInner::NonTuple(_)), + "Should never call `GenericContext::specialize` on a tuple context" + ); + Specialization::new(db, self, SpecializationInner::NonTuple(types)) } /// Creates a specialization of this generic context for the `tuple` class. pub(crate) fn specialize_tuple( self, db: &'db dyn Db, - element_type: Type<'db>, tuple: TupleType<'db>, ) -> Specialization<'db> { - Specialization::new(db, self, Box::from([element_type]), Some(tuple)) + debug_assert!( + matches!(self.inner(db), GenericContextInner::Tuple { .. }), + "Should never call `GenericContext::specialize_tuple` on a non-tuple context" + ); + Specialization::new(db, self, SpecializationInner::Tuple(tuple)) } /// Creates a specialization of this generic context. Panics if the length of `types` does not @@ -319,8 +382,14 @@ impl<'db> GenericContext<'db> { db: &'db dyn Db, types: &[Option>], ) -> Specialization<'db> { + if let GenericContextInner::Tuple { .. } = self.inner(db) { + assert_eq!(types.len(), 1); + let ty = types[0].unwrap_or_else(Type::unknown); + return self.specialize_tuple(db, TupleType::homogeneous(db, ty)); + } + let variables = self.variables(db); - assert!(variables.len() == types.len()); + assert_eq!(variables.len(), types.len()); // Typevars can have other typevars as their default values, e.g. // @@ -353,20 +422,40 @@ impl<'db> GenericContext<'db> { expanded[idx] = default; } - Specialization::new(db, self, expanded.into_boxed_slice(), None) + Specialization::new( + db, + self, + SpecializationInner::NonTuple(expanded.into_boxed_slice()), + ) } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - let variables: FxOrderSet<_> = self - .variables(db) - .iter() - .map(|bound_typevar| bound_typevar.normalized_impl(db, visitor)) - .collect(); - Self::new(db, variables) + let inner = match self.inner(db) { + GenericContextInner::Tuple { single_typevar } => { + let single_typevar = single_typevar.normalized_impl(db, visitor); + GenericContextInner::Tuple { single_typevar } + } + GenericContextInner::NonTuple(variables) => { + let variables: FxOrderSet<_> = variables + .into_iter() + .map(|bound_typevar| bound_typevar.normalized_impl(db, visitor)) + .collect(); + GenericContextInner::NonTuple(variables) + } + }; + + Self::new(db, inner) } - fn heap_size((variables,): &(FxOrderSet>,)) -> usize { - ruff_memory_usage::order_set_heap_size(variables) + fn heap_size((inner,): &(GenericContextInner<'db>,)) -> usize { + match inner { + GenericContextInner::Tuple { single_typevar } => { + ruff_memory_usage::heap_size(single_typevar) + } + GenericContextInner::NonTuple(variables) => { + ruff_memory_usage::order_set_heap_size(variables) + } + } } } @@ -392,7 +481,7 @@ impl std::fmt::Display for LegacyGenericBase { } #[derive(Debug, Clone, PartialEq, Eq, Hash, get_size2::GetSize)] -enum SpecializationInner<'db> { +pub enum SpecializationInner<'db> { Tuple(TupleType<'db>), NonTuple(Box<[Type<'db>]>), } @@ -423,7 +512,7 @@ pub(super) fn walk_specialization<'db, V: super::visitor::TypeVisitor<'db> + ?Si } impl<'db> Specialization<'db> { - pub(crate) fn types(self, db: &'db dyn Db) -> &[Type<'db>] { + pub(crate) fn types(self, db: &'db dyn Db) -> &'db [Type<'db>] { #[salsa::tracked(returns(ref))] fn homogeneous_element_type<'db>(db: &'db dyn Db, tuple: TupleType<'db>) -> Type<'db> { tuple.tuple(db).homogeneous_element_type(db) @@ -452,11 +541,18 @@ impl<'db> Specialization<'db> { db: &'db dyn Db, bound_typevar: BoundTypeVarInstance<'db>, ) -> Option> { - let index = self - .generic_context(db) - .variables(db) - .get_index_of(&bound_typevar)?; - self.types(db).get(index).copied() + match self.generic_context(db).inner(db) { + GenericContextInner::Tuple { single_typevar } => (*single_typevar == bound_typevar) + .then(|| { + let types = self.types(db); + assert_eq!(types.len(), 1); + types[0] + }), + GenericContextInner::NonTuple(variables) => { + let index = variables.get_index_of(&bound_typevar)?; + self.types(db).get(index).copied() + } + } } /// Applies a specialization to this specialization. This is used, for instance, when a generic @@ -526,7 +622,13 @@ impl<'db> Specialization<'db> { /// Panics if the two specializations are not for the same generic context. pub(crate) fn combine(self, db: &'db dyn Db, other: Self) -> Self { let generic_context = self.generic_context(db); - assert!(other.generic_context(db) == generic_context); + + assert_eq!(other.generic_context(db), generic_context); + debug_assert!( + matches!(self.inner(db), SpecializationInner::NonTuple(_)), + "The tuple constructor is special-cased everywhere" + ); + // TODO special-casing Unknown to mean "no mapping" is not right here, and can give // confusing/wrong results in cases where there was a mapping found for a typevar, and it // was of type Unknown. We should probably add a bitset or similar to Specialization that @@ -540,8 +642,12 @@ impl<'db> Specialization<'db> { _ => UnionType::from_elements(db, [self_type, other_type]), }) .collect(); - // TODO: Combine the tuple specs too - Specialization::new(db, self.generic_context(db), types, None) + + Specialization::new( + db, + self.generic_context(db), + SpecializationInner::NonTuple(types), + ) } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { @@ -564,26 +670,32 @@ impl<'db> Specialization<'db> { } pub(super) fn materialize(self, db: &'db dyn Db, variance: TypeVarVariance) -> Self { - let types: Box<[_]> = self - .generic_context(db) - .variables(db) - .into_iter() - .zip(self.types(db)) - .map(|(bound_typevar, vartype)| { - let variance = match bound_typevar.typevar(db).variance(db) { - TypeVarVariance::Invariant => TypeVarVariance::Invariant, - TypeVarVariance::Covariant => variance, - TypeVarVariance::Contravariant => variance.flip(), - TypeVarVariance::Bivariant => unreachable!(), - }; - vartype.materialize(db, variance) - }) - .collect(); - let tuple_inner = self.tuple_inner(db).and_then(|tuple| { - // Tuples are immutable, so tuple element types are always in covariant position. - tuple.materialize(db, variance) - }); - Specialization::new(db, self.generic_context(db), types, tuple_inner) + let inner = match self.inner(db) { + SpecializationInner::Tuple(tuple) => SpecializationInner::Tuple( + tuple + .materialize(db, variance) + .unwrap_or_else(|| TupleType::empty(db)), + ), + SpecializationInner::NonTuple(types) => { + let types: Box<[_]> = self + .generic_context(db) + .variables(db) + .zip(types) + .map(|(bound_typevar, vartype)| { + let variance = match bound_typevar.typevar(db).variance(db) { + TypeVarVariance::Invariant => TypeVarVariance::Invariant, + TypeVarVariance::Covariant => variance, + TypeVarVariance::Contravariant => variance.flip(), + TypeVarVariance::Bivariant => unreachable!(), + }; + vartype.materialize(db, variance) + }) + .collect(); + SpecializationInner::NonTuple(types) + } + }; + + Specialization::new(db, self.generic_context(db), inner) } pub(crate) fn has_relation_to_impl( @@ -598,12 +710,14 @@ impl<'db> Specialization<'db> { return false; } - if let (Some(self_tuple), Some(other_tuple)) = (self.tuple_inner(db), other.tuple_inner(db)) + if let (SpecializationInner::Tuple(left), SpecializationInner::Tuple(right)) = + (self.inner(db), other.inner(db)) { - return self_tuple.has_relation_to_impl(db, other_tuple, relation, visitor); + return left.has_relation_to_impl(db, *right, relation, visitor); } - for ((bound_typevar, self_type), other_type) in (generic_context.variables(db).into_iter()) + for ((bound_typevar, self_type), other_type) in generic_context + .variables(db) .zip(self.types(db)) .zip(other.types(db)) { @@ -650,7 +764,8 @@ impl<'db> Specialization<'db> { return false; } - for ((bound_typevar, self_type), other_type) in (generic_context.variables(db).into_iter()) + for ((bound_typevar, self_type), other_type) in generic_context + .variables(db) .zip(self.types(db)) .zip(other.types(db)) { @@ -724,11 +839,19 @@ impl<'db> PartialSpecialization<'_, 'db> { db: &'db dyn Db, bound_typevar: BoundTypeVarInstance<'db>, ) -> Option> { - let index = self - .generic_context - .variables(db) - .get_index_of(&bound_typevar)?; - self.types.get(index).copied() + match self.generic_context.inner(db) { + GenericContextInner::Tuple { single_typevar } => { + if bound_typevar == *single_typevar { + self.types.first().copied() + } else { + None + } + } + GenericContextInner::NonTuple(variables) => { + let index = variables.get_index_of(&bound_typevar)?; + self.types.get(index).copied() + } + } } pub(crate) fn to_owned(&self) -> PartialSpecialization<'db, 'db> { @@ -773,18 +896,30 @@ impl<'db> SpecializationBuilder<'db> { } pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> { - let types: Box<[_]> = generic_context - .variables(self.db) - .iter() - .map(|variable| { - self.types - .get(variable) + let inner = match generic_context.inner(self.db) { + GenericContextInner::Tuple { single_typevar } => { + let ty = self + .types + .get(single_typevar) .copied() - .unwrap_or(variable.default_type(self.db).unwrap_or(Type::unknown())) - }) - .collect(); - // TODO Infer the tuple spec for a tuple type - Specialization::new(self.db, generic_context, types, None) + .unwrap_or_else(Type::unknown); + SpecializationInner::Tuple(TupleType::homogeneous(self.db, ty)) + } + GenericContextInner::NonTuple(variables) => { + let types: Box<[_]> = variables + .iter() + .map(|bound_typevar| { + self.types.get(bound_typevar).copied().unwrap_or( + bound_typevar + .default_type(self.db) + .unwrap_or_else(Type::unknown), + ) + }) + .collect(); + SpecializationInner::NonTuple(types) + } + }; + Specialization::new(self.db, generic_context, inner) } fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 8bba369c3b..b37ba85bdb 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -110,7 +110,7 @@ use crate::types::enums::is_enum_class; use crate::types::function::{ FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral, }; -use crate::types::generics::{GenericContext, bind_typevar}; +use crate::types::generics::{GenericContext, GenericContextInner, bind_typevar}; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; use crate::types::signatures::{CallableSignature, Signature}; @@ -9061,7 +9061,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } }) .collect(); - typevars.map(|typevars| GenericContext::new(self.db(), typevars)) + typevars + .map(|typevars| GenericContext::new(self.db(), GenericContextInner::NonTuple(typevars))) } fn infer_slice_expression(&mut self, slice: &ast::ExprSlice) -> Type<'db> { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 6058b4e49a..8fa835de0f 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -131,16 +131,6 @@ pub struct TupleType<'db> { pub(crate) tuple: TupleSpec<'db>, } -pub(super) fn walk_tuple_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( - db: &'db dyn Db, - tuple: TupleType<'db>, - visitor: &V, -) { - for element in tuple.tuple(db).all_elements() { - visitor.visit_type(db, *element); - } -} - // The Salsa heap is tracked separately. impl get_size2::GetSize for TupleType<'_> {} @@ -206,10 +196,9 @@ impl<'db> TupleType<'db> { tuple_class.apply_specialization(db, |generic_context| { if generic_context.variables(db).len() == 1 { - let element_type = self.tuple(db).homogeneous_element_type(db); - generic_context.specialize_tuple(db, element_type, self) + generic_context.specialize_tuple(db, self) } else { - generic_context.default_specialization(db, Some(KnownClass::Tuple)) + generic_context.default_specialization(db) } }) } @@ -290,9 +279,9 @@ fn to_class_type_cycle_initial<'db>(db: &'db dyn Db, self_: TupleType<'db>) -> C tuple_class.apply_specialization(db, |generic_context| { if generic_context.variables(db).len() == 1 { - generic_context.specialize_tuple(db, Type::Never, self_) + generic_context.specialize_tuple(db, self_) } else { - generic_context.default_specialization(db, Some(KnownClass::Tuple)) + generic_context.default_specialization(db) } }) }