diff --git a/crates/ruff_memory_usage/src/lib.rs b/crates/ruff_memory_usage/src/lib.rs index d2002bc6bb..e75c75808e 100644 --- a/crates/ruff_memory_usage/src/lib.rs +++ b/crates/ruff_memory_usage/src/lib.rs @@ -1,7 +1,7 @@ use std::sync::{LazyLock, Mutex}; use get_size2::{GetSize, StandardTracker}; -use ordermap::OrderSet; +use ordermap::{OrderMap, OrderSet}; /// Returns the memory usage of the provided object, using a global tracker to avoid /// double-counting shared objects. @@ -18,3 +18,11 @@ pub fn heap_size(value: &T) -> usize { pub fn order_set_heap_size(set: &OrderSet) -> usize { (set.capacity() * T::get_stack_size()) + set.iter().map(heap_size).sum::() } + +/// An implementation of [`GetSize::get_heap_size`] for [`OrderMap`]. +pub fn order_map_heap_size(map: &OrderMap) -> usize { + (map.capacity() * (K::get_stack_size() + V::get_stack_size())) + + (map.iter()) + .map(|(k, v)| heap_size(k) + heap_size(v)) + .sum::() +} diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index fe5bcc5741..bb359c0b54 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5456,28 +5456,19 @@ impl<'db> Type<'db> { } } - let new_specialization = new_call_outcome - .and_then(Result::ok) - .as_ref() - .and_then(Bindings::single_element) - .into_iter() - .flat_map(CallableBinding::matching_overloads) - .next() - .and_then(|(_, binding)| binding.inherited_specialization()) - .filter(|specialization| { - Some(specialization.generic_context(db)) == generic_context - }); - let init_specialization = init_call_outcome - .and_then(Result::ok) - .as_ref() - .and_then(Bindings::single_element) - .into_iter() - .flat_map(CallableBinding::matching_overloads) - .next() - .and_then(|(_, binding)| binding.inherited_specialization()) - .filter(|specialization| { - Some(specialization.generic_context(db)) == generic_context - }); + let specialize_constructor = |outcome: Option>| { + let (_, binding) = outcome + .as_ref()? + .single_element()? + .matching_overloads() + .next()?; + binding.specialization()?.restrict(db, generic_context?) + }; + + let new_specialization = + specialize_constructor(new_call_outcome.and_then(Result::ok)); + let init_specialization = + specialize_constructor(init_call_outcome.and_then(Result::ok)); let specialization = combine_specializations(db, new_specialization, init_specialization); let specialized = specialization @@ -6768,13 +6759,11 @@ impl<'db> TypeMapping<'_, 'db> { db, context .variables(db) - .iter() - .filter(|var| !var.typevar(db).is_self(db)) - .copied(), + .filter(|var| !var.typevar(db).is_self(db)), ), TypeMapping::ReplaceSelf { new_upper_bound } => GenericContext::from_typevar_instances( db, - context.variables(db).iter().map(|typevar| { + context.variables(db).map(|typevar| { if typevar.typevar(db).is_self(db) { BoundTypeVarInstance::synthetic_self( db, @@ -6782,7 +6771,7 @@ impl<'db> TypeMapping<'_, 'db> { typevar.binding_context(db), ) } else { - *typevar + typevar } }), ), diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 841159221b..0f2853a5be 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -32,7 +32,7 @@ use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, - TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionBuilder, UnionType, + TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; @@ -1701,10 +1701,6 @@ impl<'db> CallableBinding<'db> { parameter_type = parameter_type.apply_specialization(db, specialization); } - if let Some(inherited_specialization) = overload.inherited_specialization { - parameter_type = - parameter_type.apply_specialization(db, inherited_specialization); - } union_parameter_types[parameter_index.saturating_sub(skipped_parameters)] .add_in_place(parameter_type); } @@ -1983,7 +1979,7 @@ impl<'db> CallableBinding<'db> { for overload in overloads.iter().take(MAXIMUM_OVERLOADS) { diag.info(format_args!( " {}", - overload.signature(context.db(), None).display(context.db()) + overload.signature(context.db()).display(context.db()) )); } if overloads.len() > MAXIMUM_OVERLOADS { @@ -2444,7 +2440,6 @@ struct ArgumentTypeChecker<'a, 'db> { errors: &'a mut Vec>, specialization: Option>, - inherited_specialization: Option>, } impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { @@ -2466,7 +2461,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { call_expression_tcx, errors, specialization: None, - inherited_specialization: None, } } @@ -2498,9 +2492,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } fn infer_specialization(&mut self) { - if self.signature.generic_context.is_none() - && self.signature.inherited_generic_context.is_none() - { + if self.signature.generic_context.is_none() { return; } @@ -2542,14 +2534,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } self.specialization = self.signature.generic_context.map(|gc| builder.build(gc)); - self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| { - // The inherited generic context is used when inferring the specialization of a generic - // class from a constructor call. In this case (only), we promote any typevars that are - // inferred as a literal to the corresponding instance type. - builder - .build(gc) - .apply_type_mapping(self.db, &TypeMapping::PromoteLiterals) - }); } fn check_argument_type( @@ -2566,11 +2550,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_type = argument_type.apply_specialization(self.db, specialization); expected_ty = expected_ty.apply_specialization(self.db, specialization); } - if let Some(inherited_specialization) = self.inherited_specialization { - argument_type = - argument_type.apply_specialization(self.db, inherited_specialization); - expected_ty = expected_ty.apply_specialization(self.db, inherited_specialization); - } // This is one of the few places where we want to check if there's _any_ specialization // where assignability holds; normally we want to check that assignability holds for // _all_ specializations. @@ -2742,8 +2721,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } } - fn finish(self) -> (Option>, Option>) { - (self.specialization, self.inherited_specialization) + fn finish(self) -> Option> { + self.specialization } } @@ -2807,10 +2786,6 @@ pub(crate) struct Binding<'db> { /// The specialization that was inferred from the argument types, if the callable is generic. specialization: Option>, - /// The specialization that was inferred for a class method's containing generic class, if it - /// is being used to infer a specialization for the class. - inherited_specialization: Option>, - /// Information about which parameter(s) each argument was matched with, in argument source /// order. argument_matches: Box<[MatchedArgument<'db>]>, @@ -2835,7 +2810,6 @@ impl<'db> Binding<'db> { signature_type, return_ty: Type::unknown(), specialization: None, - inherited_specialization: None, argument_matches: Box::from([]), variadic_argument_matched_to_variadic_parameter: false, parameter_tys: Box::from([]), @@ -2906,15 +2880,10 @@ impl<'db> Binding<'db> { checker.infer_specialization(); checker.check_argument_types(); - (self.specialization, self.inherited_specialization) = checker.finish(); + self.specialization = checker.finish(); if let Some(specialization) = self.specialization { self.return_ty = self.return_ty.apply_specialization(db, specialization); } - if let Some(inherited_specialization) = self.inherited_specialization { - self.return_ty = self - .return_ty - .apply_specialization(db, inherited_specialization); - } } pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) { @@ -2925,8 +2894,8 @@ impl<'db> Binding<'db> { self.return_ty } - pub(crate) fn inherited_specialization(&self) -> Option> { - self.inherited_specialization + pub(crate) fn specialization(&self) -> Option> { + self.specialization } /// Returns the bound types for each parameter, in parameter source order, or `None` if no @@ -2988,7 +2957,6 @@ impl<'db> Binding<'db> { BindingSnapshot { return_ty: self.return_ty, specialization: self.specialization, - inherited_specialization: self.inherited_specialization, argument_matches: self.argument_matches.clone(), parameter_tys: self.parameter_tys.clone(), errors: self.errors.clone(), @@ -2999,7 +2967,6 @@ impl<'db> Binding<'db> { let BindingSnapshot { return_ty, specialization, - inherited_specialization, argument_matches, parameter_tys, errors, @@ -3007,7 +2974,6 @@ impl<'db> Binding<'db> { self.return_ty = return_ty; self.specialization = specialization; - self.inherited_specialization = inherited_specialization; self.argument_matches = argument_matches; self.parameter_tys = parameter_tys; self.errors = errors; @@ -3027,7 +2993,6 @@ impl<'db> Binding<'db> { fn reset(&mut self) { self.return_ty = Type::unknown(); self.specialization = None; - self.inherited_specialization = None; self.argument_matches = Box::from([]); self.parameter_tys = Box::from([]); self.errors.clear(); @@ -3038,7 +3003,6 @@ impl<'db> Binding<'db> { struct BindingSnapshot<'db> { return_ty: Type<'db>, specialization: Option>, - inherited_specialization: Option>, argument_matches: Box<[MatchedArgument<'db>]>, parameter_tys: Box<[Option>]>, errors: Vec>, @@ -3078,7 +3042,6 @@ impl<'db> CallableBindingSnapshot<'db> { // ... and update the snapshot with the current state of the binding. snapshot.return_ty = binding.return_ty; snapshot.specialization = binding.specialization; - snapshot.inherited_specialization = binding.inherited_specialization; snapshot .argument_matches .clone_from(&binding.argument_matches); @@ -3373,7 +3336,7 @@ impl<'db> BindingError<'db> { } diag.info(format_args!( " {}", - overload.signature(context.db(), None).display(context.db()) + overload.signature(context.db()).display(context.db()) )); } if overloads.len() > MAXIMUM_OVERLOADS { diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index a6eb27f6e1..ae44c0cebd 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -324,7 +324,6 @@ impl<'db> VarianceInferable<'db> for GenericAlias<'db> { specialization .generic_context(db) .variables(db) - .iter() .zip(specialization.types(db)) .map(|(generic_typevar, ty)| { if let Some(explicit_variance) = @@ -346,7 +345,7 @@ impl<'db> VarianceInferable<'db> for GenericAlias<'db> { let typevar_variance_in_substituted_type = ty.variance_of(db, typevar); origin .with_polarity(typevar_variance_in_substituted_type) - .variance_of(db, *generic_typevar) + .variance_of(db, generic_typevar) } }), ) @@ -1013,8 +1012,7 @@ impl<'db> ClassType<'db> { let synthesized_dunder = CallableType::function_like( db, - Signature::new(parameters, None) - .with_inherited_generic_context(inherited_generic_context), + Signature::new_generic(inherited_generic_context, parameters, None), ); Place::bound(synthesized_dunder).into() @@ -1454,6 +1452,16 @@ impl<'db> ClassLiteral<'db> { ) } + /// Returns the generic context that should be inherited by any constructor methods of this + /// class. + /// + /// When inferring a specialization of the class's generic context from a constructor call, we + /// promote any typevars that are inferred as a literal to the corresponding instance type. + fn inherited_generic_context(self, db: &'db dyn Db) -> Option> { + self.generic_context(db) + .map(|generic_context| generic_context.promote_literals(db)) + } + fn file(self, db: &dyn Db) -> File { self.body_scope(db).file(db) } @@ -1996,7 +2004,7 @@ impl<'db> ClassLiteral<'db> { lookup_result = lookup_result.or_else(|lookup_error| { lookup_error.or_fall_back_to( db, - class.own_class_member(db, self.generic_context(db), name), + class.own_class_member(db, self.inherited_generic_context(db), name), ) }); } @@ -2246,8 +2254,14 @@ impl<'db> ClassLiteral<'db> { // so that the keyword-only parameters appear after positional parameters. parameters.sort_by_key(Parameter::is_keyword_only); - let mut signature = Signature::new(Parameters::new(parameters), return_ty); - signature.inherited_generic_context = self.generic_context(db); + let signature = match name { + "__new__" | "__init__" => Signature::new_generic( + self.inherited_generic_context(db), + Parameters::new(parameters), + return_ty, + ), + _ => Signature::new(Parameters::new(parameters), return_ty), + }; Some(CallableType::function_like(db, signature)) }; @@ -2295,7 +2309,7 @@ impl<'db> ClassLiteral<'db> { KnownClass::NamedTupleFallback .to_class_literal(db) .into_class_literal()? - .own_class_member(db, self.generic_context(db), None, name) + .own_class_member(db, self.inherited_generic_context(db), None, name) .place .ignore_possibly_unbound() .map(|ty| { @@ -5421,7 +5435,7 @@ enum SlotsKind { impl SlotsKind { fn from(db: &dyn Db, base: ClassLiteral) -> Self { let Place::Type(slots_ty, bound) = base - .own_class_member(db, base.generic_context(db), None, "__slots__") + .own_class_member(db, base.inherited_generic_context(db), None, "__slots__") .place else { return Self::NotSpecified; diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 4bf9c235d2..f20ebefd21 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -654,7 +654,7 @@ pub(crate) struct DisplayOverloadLiteral<'db> { impl Display for DisplayOverloadLiteral<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let signature = self.literal.signature(self.db, None); + let signature = self.literal.signature(self.db); let type_parameters = DisplayOptionalGenericContext { generic_context: signature.generic_context.as_ref(), db: self.db, @@ -832,7 +832,6 @@ impl Display for DisplayGenericContext<'_> { let variables = self.generic_context.variables(self.db); let non_implicit_variables: Vec<_> = variables - .iter() .filter(|bound_typevar| !bound_typevar.typevar(self.db).is_self(self.db)) .collect(); @@ -852,6 +851,10 @@ impl Display for DisplayGenericContext<'_> { } impl<'db> Specialization<'db> { + pub fn display(&'db self, db: &'db dyn Db) -> DisplaySpecialization<'db> { + self.display_short(db, TupleSpecialization::No, DisplaySettings::default()) + } + /// Renders the specialization as it would appear in a subscript expression, e.g. `[int, str]`. pub fn display_short( &'db self, diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 406baaf136..2b06665452 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -72,7 +72,7 @@ use crate::types::diagnostic::{ report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface, report_runtime_check_against_non_runtime_checkable_protocol, }; -use crate::types::generics::{GenericContext, walk_generic_context}; +use crate::types::generics::GenericContext; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; @@ -338,11 +338,7 @@ impl<'db> OverloadLiteral<'db> { /// calling query is not in the same file as this function is defined in, then this will create /// a cross-module dependency directly on the full AST which will lead to cache /// over-invalidation. - pub(crate) fn signature( - self, - db: &'db dyn Db, - inherited_generic_context: Option>, - ) -> Signature<'db> { + pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { /// `self` or `cls` can be implicitly positional-only if: /// - It is a method AND /// - No parameters in the method use PEP-570 syntax AND @@ -420,7 +416,6 @@ impl<'db> OverloadLiteral<'db> { Signature::from_function( db, generic_context, - inherited_generic_context, definition, function_stmt_node, is_generator, @@ -484,58 +479,13 @@ impl<'db> OverloadLiteral<'db> { #[derive(PartialOrd, Ord)] pub struct FunctionLiteral<'db> { pub(crate) last_definition: OverloadLiteral<'db>, - - /// The inherited generic context, if this function is a constructor method (`__new__` or - /// `__init__`) being used to infer the specialization of its generic class. If any of the - /// method's overloads are themselves generic, this is in addition to those per-overload - /// generic contexts (which are created lazily in [`OverloadLiteral::signature`]). - /// - /// If the function is not a constructor method, this field will always be `None`. - /// - /// If the function is a constructor method, we will end up creating two `FunctionLiteral` - /// instances for it. The first is created in [`TypeInferenceBuilder`][infer] when we encounter - /// the function definition during type inference. At this point, we don't yet know if the - /// function is a constructor method, so we create a `FunctionLiteral` with `None` for this - /// field. - /// - /// If at some point we encounter a call expression, which invokes the containing class's - /// constructor, as will create a _new_ `FunctionLiteral` instance for the function, with this - /// field [updated][] to contain the containing class's generic context. - /// - /// [infer]: crate::types::infer::TypeInferenceBuilder::infer_function_definition - /// [updated]: crate::types::class::ClassLiteral::own_class_member - inherited_generic_context: Option>, } // The Salsa heap is tracked separately. impl get_size2::GetSize for FunctionLiteral<'_> {} -fn walk_function_literal<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( - db: &'db dyn Db, - function: FunctionLiteral<'db>, - visitor: &V, -) { - if let Some(context) = function.inherited_generic_context(db) { - walk_generic_context(db, context, visitor); - } -} - #[salsa::tracked] impl<'db> FunctionLiteral<'db> { - fn with_inherited_generic_context( - self, - db: &'db dyn Db, - inherited_generic_context: GenericContext<'db>, - ) -> Self { - // A function cannot inherit more than one generic context from its containing class. - debug_assert!(self.inherited_generic_context(db).is_none()); - Self::new( - db, - self.last_definition(db), - Some(inherited_generic_context), - ) - } - fn name(self, db: &'db dyn Db) -> &'db ast::name::Name { // All of the overloads of a function literal should have the same name. self.last_definition(db).name(db) @@ -626,21 +576,14 @@ impl<'db> FunctionLiteral<'db> { fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> { // We only include an implementation (i.e. a definition not decorated with `@overload`) if // it's the only definition. - let inherited_generic_context = self.inherited_generic_context(db); let (overloads, implementation) = self.overloads_and_implementation(db); if let Some(implementation) = implementation { if overloads.is_empty() { - return CallableSignature::single( - implementation.signature(db, inherited_generic_context), - ); + return CallableSignature::single(implementation.signature(db)); } } - CallableSignature::from_overloads( - overloads - .iter() - .map(|overload| overload.signature(db, inherited_generic_context)), - ) + CallableSignature::from_overloads(overloads.iter().map(|overload| overload.signature(db))) } /// Typed externally-visible signature of the last overload or implementation of this function. @@ -652,16 +595,7 @@ impl<'db> FunctionLiteral<'db> { /// a cross-module dependency directly on the full AST which will lead to cache /// over-invalidation. fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> { - let inherited_generic_context = self.inherited_generic_context(db); - self.last_definition(db) - .signature(db, inherited_generic_context) - } - - fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - let context = self - .inherited_generic_context(db) - .map(|ctx| ctx.normalized_impl(db, visitor)); - Self::new(db, self.last_definition(db), context) + self.last_definition(db).signature(db) } } @@ -695,7 +629,6 @@ pub(super) fn walk_function_type<'db, V: super::visitor::TypeVisitor<'db> + ?Siz function: FunctionType<'db>, visitor: &V, ) { - walk_function_literal(db, function.literal(db), visitor); if let Some(callable_signature) = function.updated_signature(db) { for signature in &callable_signature.overloads { walk_signature(db, signature, visitor); @@ -713,23 +646,18 @@ impl<'db> FunctionType<'db> { db: &'db dyn Db, inherited_generic_context: GenericContext<'db>, ) -> Self { - let literal = self - .literal(db) + let updated_signature = self + .signature(db) + .with_inherited_generic_context(db, inherited_generic_context); + let updated_last_definition_signature = self + .last_definition_signature(db) + .clone() .with_inherited_generic_context(db, inherited_generic_context); - let updated_signature = self.updated_signature(db).map(|signature| { - signature.with_inherited_generic_context(Some(inherited_generic_context)) - }); - let updated_last_definition_signature = - self.updated_last_definition_signature(db).map(|signature| { - signature - .clone() - .with_inherited_generic_context(Some(inherited_generic_context)) - }); Self::new( db, - literal, - updated_signature, - updated_last_definition_signature, + self.literal(db), + Some(updated_signature), + Some(updated_last_definition_signature), ) } @@ -764,8 +692,7 @@ impl<'db> FunctionType<'db> { let last_definition = literal .last_definition(db) .with_dataclass_transformer_params(db, params); - let literal = - FunctionLiteral::new(db, last_definition, literal.inherited_generic_context(db)); + let literal = FunctionLiteral::new(db, last_definition); Self::new(db, literal, None, None) } @@ -1036,7 +963,7 @@ impl<'db> FunctionType<'db> { } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - let literal = self.literal(db).normalized_impl(db, visitor); + let literal = self.literal(db); let updated_signature = self .updated_signature(db) .map(|signature| signature.normalized_impl(db, visitor)); diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 21e387017b..55b8a26638 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -19,7 +19,7 @@ use crate::types::{ NormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type, }; -use crate::{Db, FxOrderSet}; +use crate::{Db, FxOrderMap, FxOrderSet}; /// Returns an iterator of any generic context introduced by the given scope or any enclosing /// scope. @@ -137,19 +137,28 @@ pub(crate) fn typing_self<'db>( .map(typevar_to_type) } +#[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq, get_size2::GetSize)] +pub struct GenericContextTypeVarOptions { + should_promote_literals: bool, +} + +impl GenericContextTypeVarOptions { + fn promote_literals(mut self) -> Self { + self.should_promote_literals = true; + self + } +} + /// A list of formal type variables for a generic function, class, or type alias. /// -/// TODO: Handle nested generic contexts better, with actual parent links to the lexically -/// containing context. -/// /// # Ordering /// Ordering is based on the context's salsa-assigned id and not on its values. /// The id may change between runs, or when the context was garbage collected and recreated. -#[salsa::interned(debug, heap_size=GenericContext::heap_size)] +#[salsa::interned(debug, constructor=new_internal, heap_size=GenericContext::heap_size)] #[derive(PartialOrd, Ord)] pub struct GenericContext<'db> { #[returns(ref)] - pub(crate) variables: FxOrderSet>, + variables_inner: FxOrderMap, GenericContextTypeVarOptions>, } pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( @@ -158,7 +167,7 @@ pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?S visitor: &V, ) { for bound_typevar in context.variables(db) { - visitor.visit_bound_type_var_type(db, *bound_typevar); + visitor.visit_bound_type_var_type(db, bound_typevar); } } @@ -166,6 +175,13 @@ pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?S impl get_size2::GetSize for GenericContext<'_> {} impl<'db> GenericContext<'db> { + fn from_variables( + db: &'db dyn Db, + variables: impl IntoIterator, GenericContextTypeVarOptions)>, + ) -> Self { + Self::new_internal(db, variables.into_iter().collect::>()) + } + /// Creates a generic context from a list of PEP-695 type parameters. pub(crate) fn from_type_params( db: &'db dyn Db, @@ -185,21 +201,44 @@ impl<'db> GenericContext<'db> { db: &'db dyn Db, type_params: impl IntoIterator>, ) -> Self { - Self::new(db, type_params.into_iter().collect::>()) + Self::from_variables( + db, + type_params + .into_iter() + .map(|bound_typevar| (bound_typevar, GenericContextTypeVarOptions::default())), + ) + } + + /// Returns a copy of this generic context where we will promote literal types in any inferred + /// specializations. + pub(crate) fn promote_literals(self, db: &'db dyn Db) -> Self { + Self::from_variables( + db, + self.variables_inner(db) + .iter() + .map(|(bound_typevar, options)| (*bound_typevar, options.promote_literals())), + ) } /// Merge this generic context with another, returning a new generic context that /// contains type variables from both contexts. pub(crate) fn merge(self, db: &'db dyn Db, other: Self) -> Self { - Self::from_typevar_instances( + Self::from_variables( db, - self.variables(db) + self.variables_inner(db) .iter() - .chain(other.variables(db).iter()) - .copied(), + .chain(other.variables_inner(db).iter()) + .map(|(bound_typevar, options)| (*bound_typevar, *options)), ) } + pub(crate) fn variables( + self, + db: &'db dyn Db, + ) -> impl ExactSizeIterator> + Clone { + self.variables_inner(db).keys().copied() + } + fn variable_from_type_param( db: &'db dyn Db, index: &'db SemanticIndex<'db>, @@ -247,7 +286,7 @@ impl<'db> GenericContext<'db> { if variables.is_empty() { return None; } - Some(Self::new(db, variables)) + Some(Self::from_typevar_instances(db, variables)) } /// Creates a generic context from the legacy `TypeVar`s that appear in class's base class @@ -263,18 +302,17 @@ impl<'db> GenericContext<'db> { if variables.is_empty() { return None; } - Some(Self::new(db, variables)) + Some(Self::from_typevar_instances(db, variables)) } pub(crate) fn len(self, db: &'db dyn Db) -> usize { - self.variables(db).len() + self.variables_inner(db).len() } pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { let parameters = Parameters::new( self.variables(db) - .iter() - .map(|typevar| Self::parameter_from_typevar(db, *typevar)), + .map(|typevar| Self::parameter_from_typevar(db, typevar)), ); Signature::new(parameters, None) } @@ -309,8 +347,7 @@ impl<'db> GenericContext<'db> { db: &'db dyn Db, known_class: Option, ) -> Specialization<'db> { - let partial = - self.specialize_partial(db, std::iter::repeat_n(None, self.variables(db).len())); + let partial = self.specialize_partial(db, std::iter::repeat_n(None, self.len(db))); if known_class == Some(KnownClass::Tuple) { Specialization::new( db, @@ -332,31 +369,24 @@ impl<'db> GenericContext<'db> { db: &'db dyn Db, typevar_to_type: &impl Fn(BoundTypeVarInstance<'db>) -> Type<'db>, ) -> Specialization<'db> { - let types = self - .variables(db) - .iter() - .map(|typevar| typevar_to_type(*typevar)) - .collect(); + let types = self.variables(db).map(typevar_to_type).collect(); self.specialize(db, types) } pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> { - let types = vec![Type::unknown(); self.variables(db).len()]; + let types = vec![Type::unknown(); self.len(db)]; self.specialize(db, types.into()) } /// Returns a tuple type of the typevars introduced by this generic context. pub(crate) fn as_tuple(self, db: &'db dyn Db) -> Type<'db> { - Type::heterogeneous_tuple( - db, - self.variables(db) - .iter() - .map(|typevar| Type::TypeVar(*typevar)), - ) + Type::heterogeneous_tuple(db, self.variables(db).map(Type::TypeVar)) } pub(crate) fn is_subset_of(self, db: &'db dyn Db, other: GenericContext<'db>) -> bool { - self.variables(db).is_subset(other.variables(db)) + let other_variables = other.variables_inner(db); + self.variables(db) + .all(|bound_typevar| other_variables.contains_key(&bound_typevar)) } pub(crate) fn binds_typevar( @@ -365,9 +395,7 @@ impl<'db> GenericContext<'db> { typevar: TypeVarInstance<'db>, ) -> Option> { self.variables(db) - .iter() .find(|self_bound_typevar| self_bound_typevar.typevar(db) == typevar) - .copied() } /// Creates a specialization of this generic context. Panics if the length of `types` does not @@ -379,7 +407,7 @@ impl<'db> GenericContext<'db> { db: &'db dyn Db, types: Box<[Type<'db>]>, ) -> Specialization<'db> { - assert!(self.variables(db).len() == types.len()); + assert!(self.len(db) == types.len()); Specialization::new(db, self, types, None, None) } @@ -403,7 +431,7 @@ impl<'db> GenericContext<'db> { { let types = types.into_iter(); let variables = self.variables(db); - assert!(variables.len() == types.len()); + assert!(self.len(db) == types.len()); // Typevars can have other typevars as their default values, e.g. // @@ -442,14 +470,15 @@ impl<'db> GenericContext<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { let variables = self .variables(db) - .iter() .map(|bound_typevar| bound_typevar.normalized_impl(db, visitor)); Self::from_typevar_instances(db, variables) } - fn heap_size((variables,): &(FxOrderSet>,)) -> usize { - ruff_memory_usage::order_set_heap_size(variables) + fn heap_size( + (variables,): &(FxOrderMap, GenericContextTypeVarOptions>,), + ) -> usize { + ruff_memory_usage::order_map_heap_size(variables) } } @@ -661,6 +690,31 @@ fn has_relation_in_invariant_position<'db>( } impl<'db> Specialization<'db> { + /// Restricts this specialization to only include the typevars in a generic context. If the + /// specialization does not include all of those typevars, returns `None`. + pub(crate) fn restrict( + self, + db: &'db dyn Db, + generic_context: GenericContext<'db>, + ) -> Option { + let self_variables = self.generic_context(db).variables_inner(db); + let self_types = self.types(db); + let restricted_variables = generic_context.variables(db); + let restricted_types: Option> = restricted_variables + .map(|variable| { + let index = self_variables.get_index_of(&variable)?; + self_types.get(index).copied() + }) + .collect(); + Some(Self::new( + db, + generic_context, + restricted_types?, + self.materialization_kind(db), + None, + )) + } + /// Returns the tuple spec for a specialization of the `tuple` class. pub(crate) fn tuple(self, db: &'db dyn Db) -> Option<&'db TupleSpec<'db>> { self.tuple_inner(db).map(|tuple_type| tuple_type.tuple(db)) @@ -675,7 +729,7 @@ impl<'db> Specialization<'db> { ) -> Option> { let index = self .generic_context(db) - .variables(db) + .variables_inner(db) .get_index_of(&bound_typevar)?; self.types(db).get(index).copied() } @@ -813,7 +867,6 @@ impl<'db> Specialization<'db> { let types: Box<[_]> = self .generic_context(db) .variables(db) - .into_iter() .zip(self.types(db)) .map(|(bound_typevar, vartype)| { match bound_typevar.variance(db) { @@ -882,7 +935,7 @@ impl<'db> Specialization<'db> { let other_materialization_kind = other.materialization_kind(db); let mut result = ConstraintSet::from(true); - for ((bound_typevar, self_type), other_type) in (generic_context.variables(db).into_iter()) + for ((bound_typevar, self_type), other_type) in (generic_context.variables(db)) .zip(self.types(db)) .zip(other.types(db)) { @@ -933,7 +986,7 @@ impl<'db> Specialization<'db> { } let mut result = ConstraintSet::from(true); - for ((bound_typevar, self_type), other_type) in (generic_context.variables(db).into_iter()) + for ((bound_typevar, self_type), other_type) in (generic_context.variables(db)) .zip(self.types(db)) .zip(other.types(db)) { @@ -1005,7 +1058,7 @@ impl<'db> PartialSpecialization<'_, 'db> { ) -> Option> { let index = self .generic_context - .variables(db) + .variables_inner(db) .get_index_of(&bound_typevar)?; self.types.get(index).copied() } @@ -1027,10 +1080,18 @@ impl<'db> SpecializationBuilder<'db> { } pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> { - let types = generic_context - .variables(self.db) - .iter() - .map(|variable| self.types.get(variable).copied()); + let types = (generic_context.variables_inner(self.db).iter()).map(|(variable, options)| { + let mut ty = self.types.get(variable).copied(); + + // When inferring a specialization for a generic class typevar from a constructor call, + // promote any typevars that are inferred as a literal to the corresponding instance + // type. + if options.should_promote_literals { + ty = ty.map(|ty| ty.promote_literals(self.db)); + } + + ty + }); // TODO Infer the tuple spec for a tuple type generic_context.specialize_partial(self.db, types) } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 55c799a656..b4fe2b5a23 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -88,13 +88,12 @@ use crate::types::typed_dict::{ }; use crate::types::visitor::any_over_type; use crate::types::{ - BoundTypeVarInstance, CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, - DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, - MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, - Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, - TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, - UnionBuilder, UnionType, binding_type, todo_type, + CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType, + IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, + MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, + SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, + TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, + TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -2141,10 +2140,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { deprecated, dataclass_transformer_params, ); - - let inherited_generic_context = None; - let function_literal = - FunctionLiteral::new(self.db(), overload_literal, inherited_generic_context); + let function_literal = FunctionLiteral::new(self.db(), overload_literal); let mut inferred_ty = Type::FunctionLiteral(FunctionType::new(self.db(), function_literal, None, None)); @@ -5354,16 +5350,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { collection_class: KnownClass, ) -> Option> { // Extract the type variable `T` from `list[T]` in typeshed. - fn elt_tys( - collection_class: KnownClass, - db: &dyn Db, - ) -> Option<(ClassLiteral<'_>, &FxOrderSet>)> { - let class_literal = collection_class.try_to_class_literal(db)?; - let generic_context = class_literal.generic_context(db)?; - Some((class_literal, generic_context.variables(db))) - } + let elt_tys = |collection_class: KnownClass| { + let class_literal = collection_class.try_to_class_literal(self.db())?; + let generic_context = class_literal.generic_context(self.db())?; + Some((class_literal, generic_context.variables(self.db()))) + }; - let (class_literal, elt_tys) = elt_tys(collection_class, self.db()).unwrap_or_else(|| { + let (class_literal, elt_tys) = elt_tys(collection_class).unwrap_or_else(|| { let name = collection_class.name(self.db()); panic!("Typeshed should always have a `{name}` class in `builtins.pyi`") }); @@ -5382,9 +5375,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Note that we infer the annotated type _before_ the elements, to more closely match the // order of any unions as written in the type annotation. Some(annotated_elt_tys) => { - for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys, annotated_elt_tys) { + for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys.clone(), annotated_elt_tys) { builder - .infer(Type::TypeVar(*elt_ty), *annotated_elt_ty) + .infer(Type::TypeVar(elt_ty), *annotated_elt_ty) .ok()?; } } @@ -5392,10 +5385,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // If a valid type annotation was not provided, avoid restricting the type of the collection // by unioning the inferred type with `Unknown`. None => { - for elt_ty in elt_tys { - builder - .infer(Type::TypeVar(*elt_ty), Type::unknown()) - .ok()?; + for elt_ty in elt_tys.clone() { + builder.infer(Type::TypeVar(elt_ty), Type::unknown()).ok()?; } } } @@ -5415,10 +5406,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { inferred_value_ty.known_specialization(KnownClass::Dict, self.db()) { for (elt_ty, inferred_elt_ty) in - iter::zip(elt_tys, specialization.types(self.db())) + iter::zip(elt_tys.clone(), specialization.types(self.db())) { builder - .infer(Type::TypeVar(*elt_ty), *inferred_elt_ty) + .infer(Type::TypeVar(elt_ty), *inferred_elt_ty) .ok()?; } } @@ -5427,7 +5418,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // The inferred type of each element acts as an additional constraint on `T`. - for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys, elt_tcxs.clone()) { + for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone()) + { let Some(inferred_elt_ty) = self.infer_optional_expression(elt, elt_tcx) else { continue; }; @@ -5436,9 +5428,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // unions for large nested list literals, which the constraint solver struggles with. let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db()); - builder - .infer(Type::TypeVar(*elt_ty), inferred_elt_ty) - .ok()?; + builder.infer(Type::TypeVar(elt_ty), inferred_elt_ty).ok()?; } } @@ -9012,7 +9002,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } }) .collect(); - typevars.map(|typevars| GenericContext::new(self.db(), typevars)) + typevars.map(|typevars| GenericContext::from_typevar_instances(self.db(), typevars)) } fn infer_slice_expression(&mut self, slice: &ast::ExprSlice) -> Type<'db> { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 46d0438c04..d6f3ee7374 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -102,12 +102,13 @@ impl<'db> CallableSignature<'db> { pub(crate) fn with_inherited_generic_context( &self, - inherited_generic_context: Option>, + db: &'db dyn Db, + inherited_generic_context: GenericContext<'db>, ) -> Self { Self::from_overloads(self.overloads.iter().map(|signature| { signature .clone() - .with_inherited_generic_context(inherited_generic_context) + .with_inherited_generic_context(db, inherited_generic_context) })) } @@ -301,11 +302,6 @@ pub struct Signature<'db> { /// The generic context for this overload, if it is generic. pub(crate) generic_context: Option>, - /// The inherited generic context, if this function is a class method being used to infer the - /// specialization of its generic class. If the method is itself generic, this is in addition - /// to its own generic context. - pub(crate) inherited_generic_context: Option>, - /// The original definition associated with this function, if available. /// This is useful for locating and extracting docstring information for the signature. pub(crate) definition: Option>, @@ -332,9 +328,6 @@ pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( if let Some(generic_context) = &signature.generic_context { walk_generic_context(db, *generic_context, visitor); } - if let Some(inherited_generic_context) = &signature.inherited_generic_context { - walk_generic_context(db, *inherited_generic_context, visitor); - } // By default we usually don't visit the type of the default value, // as it isn't relevant to most things for parameter in &signature.parameters { @@ -351,7 +344,6 @@ impl<'db> Signature<'db> { pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option>) -> Self { Self { generic_context: None, - inherited_generic_context: None, definition: None, parameters, return_ty, @@ -365,7 +357,6 @@ impl<'db> Signature<'db> { ) -> Self { Self { generic_context, - inherited_generic_context: None, definition: None, parameters, return_ty, @@ -376,7 +367,6 @@ impl<'db> Signature<'db> { pub(crate) fn dynamic(signature_type: Type<'db>) -> Self { Signature { generic_context: None, - inherited_generic_context: None, definition: None, parameters: Parameters::gradual_form(), return_ty: Some(signature_type), @@ -389,7 +379,6 @@ impl<'db> Signature<'db> { let signature_type = todo_type!(reason); Signature { generic_context: None, - inherited_generic_context: None, definition: None, parameters: Parameters::todo(), return_ty: Some(signature_type), @@ -400,7 +389,6 @@ impl<'db> Signature<'db> { pub(super) fn from_function( db: &'db dyn Db, generic_context: Option>, - inherited_generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, is_generator: bool, @@ -434,7 +422,6 @@ impl<'db> Signature<'db> { (Some(legacy_ctx), Some(ctx)) => { if legacy_ctx .variables(db) - .iter() .exactly_one() .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db)) { @@ -449,7 +436,6 @@ impl<'db> Signature<'db> { Self { generic_context: full_generic_context, - inherited_generic_context, definition: Some(definition), parameters, return_ty, @@ -468,9 +454,17 @@ impl<'db> Signature<'db> { pub(crate) fn with_inherited_generic_context( mut self, - inherited_generic_context: Option>, + db: &'db dyn Db, + inherited_generic_context: GenericContext<'db>, ) -> Self { - self.inherited_generic_context = inherited_generic_context; + match self.generic_context.as_mut() { + Some(generic_context) => { + *generic_context = generic_context.merge(db, inherited_generic_context); + } + None => { + self.generic_context = Some(inherited_generic_context); + } + } self } @@ -483,9 +477,6 @@ impl<'db> Signature<'db> { generic_context: self .generic_context .map(|ctx| ctx.normalized_impl(db, visitor)), - inherited_generic_context: self - .inherited_generic_context - .map(|ctx| ctx.normalized_impl(db, visitor)), // Discard the definition when normalizing, so that two equivalent signatures // with different `Definition`s share the same Salsa ID when normalized definition: None, @@ -516,7 +507,6 @@ impl<'db> Signature<'db> { generic_context: self .generic_context .map(|context| type_mapping.update_signature_generic_context(db, context)), - inherited_generic_context: self.inherited_generic_context, definition: self.definition, parameters: self .parameters @@ -571,7 +561,6 @@ impl<'db> Signature<'db> { } Self { generic_context: self.generic_context, - inherited_generic_context: self.inherited_generic_context, definition: self.definition, parameters, return_ty, @@ -1236,10 +1225,7 @@ impl<'db> Parameters<'db> { let method_has_self_in_generic_context = method.signature(db).overloads.iter().any(|s| { s.generic_context.is_some_and(|context| { - context - .variables(db) - .iter() - .any(|v| v.typevar(db).is_self(db)) + context.variables(db).any(|v| v.typevar(db).is_self(db)) }) }); @@ -1882,7 +1868,7 @@ mod tests { .literal(&db) .last_definition(&db); - let sig = func.signature(&db, None); + let sig = func.signature(&db); assert!(sig.return_ty.is_none()); assert_params(&sig, &[]); @@ -1907,7 +1893,7 @@ mod tests { .literal(&db) .last_definition(&db); - let sig = func.signature(&db, None); + let sig = func.signature(&db); assert_eq!(sig.return_ty.unwrap().display(&db).to_string(), "bytes"); assert_params( @@ -1959,7 +1945,7 @@ mod tests { .literal(&db) .last_definition(&db); - let sig = func.signature(&db, None); + let sig = func.signature(&db); let [ Parameter { @@ -1997,7 +1983,7 @@ mod tests { .literal(&db) .last_definition(&db); - let sig = func.signature(&db, None); + let sig = func.signature(&db); let [ Parameter { @@ -2035,7 +2021,7 @@ mod tests { .literal(&db) .last_definition(&db); - let sig = func.signature(&db, None); + let sig = func.signature(&db); let [ Parameter { @@ -2079,7 +2065,7 @@ mod tests { .literal(&db) .last_definition(&db); - let sig = func.signature(&db, None); + let sig = func.signature(&db); let [ Parameter { @@ -2116,7 +2102,7 @@ mod tests { let func = get_function_f(&db, "/src/a.py"); let overload = func.literal(&db).last_definition(&db); - let expected_sig = overload.signature(&db, None); + let expected_sig = overload.signature(&db); // With no decorators, internal and external signature are the same assert_eq!(