From 0a4dec03236ec38b48f78b07505619dfc97c4853 Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Fri, 18 Apr 2025 17:04:26 -0400 Subject: [PATCH] start pulling out enum --- crates/red_knot_python_semantic/src/types.rs | 146 ++++++++++++------ .../src/types/call/bind.rs | 31 ++-- .../src/types/class.rs | 10 +- .../src/types/context.rs | 4 +- 4 files changed, 135 insertions(+), 56 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index aeb0c67b8e..1b035f9eb2 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -5863,8 +5863,91 @@ impl<'db> IntoIterator for &'db FunctionSignature<'db> { } } +/// A callable type that represents a single Python function. +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Update)] +pub enum FunctionType<'db> { + /// A function literal in the Python AST + FunctionLiteral(FunctionLiteral<'db>), + + /// A function that has a specialization applied to its signature. + /// + /// (This does not necessarily mean that the function itself is generic — the methods of a + /// generic class, for instance, will have the class's specialization applied so that we + /// correctly substitute any class typevars that appear in the signature.) + Specialized(SpecializedFunction<'db>), + + /// A function that we treat as generic because it inherits a containing generic context. + /// + /// This is currently only used for the `__new__` and `__init__` methods of a generic class. + /// That lets us pretend those methods are generic, so that we can infer a class specialization + /// from the arguments to its constructor. + InheritedGenericContext(FunctionWithInheritedGenericContext<'db>), +} + +impl<'db> FunctionType<'db> { + fn function_literal(self, db: &'db dyn Db) -> FunctionLiteral<'db> { + match self { + FunctionType::FunctionLiteral(literal) => literal, + FunctionType::InheritedGenericContext(inherited) => inherited.function(db), + } + } + + pub(crate) fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool { + self.function_literal(db).decorators(db).contains(decorator) + } + + /// Convert the `FunctionType` into a [`Type::Callable`]. + pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { + Type::Callable(CallableType::from_overloads( + db, + self.signature(db).iter().cloned(), + )) + } + + /// Returns the [`FileRange`] of the function's name. + pub fn focus_range(self, db: &dyn Db) -> FileRange { + let body_scope = self.function_literal(db).body_scope(db); + FileRange::new( + body_scope.file(db), + body_scope.node(db).expect_function().name.range, + ) + } + + pub fn full_range(self, db: &dyn Db) -> FileRange { + let body_scope = self.function_literal(db).body_scope(db); + FileRange::new( + body_scope.file(db), + body_scope.node(db).expect_function().range, + ) + } + + pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { + let body_scope = self.function_literal(db).body_scope(db); + let index = semantic_index(db, body_scope.file(db)); + index.expect_single_definition(body_scope.node(db).expect_function()) + } + + /// Typed externally-visible signature for this function. + /// + /// This is the signature as seen by external callers, possibly modified by decorators and/or + /// overloaded. + /// + /// ## Why is this a salsa query? + /// + /// This is a salsa query to short-circuit the invalidation + /// when the function's AST node changes. + /// + /// Were this not a salsa query, then the calling query + /// would depend on the function's AST and rerun for every change in that file. + pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { + match self { + FunctionType::FunctionLiteral(literal) => literal.signature(db), + } + } +} + #[salsa::interned(debug)] -pub struct FunctionType<'db> { +pub struct FunctionLiteral<'db> { /// Name of the function at definition. #[return_ref] pub name: ast::name::Name, @@ -5888,40 +5971,11 @@ pub struct FunctionType<'db> { } #[salsa::tracked] -impl<'db> FunctionType<'db> { - pub(crate) fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool { +impl<'db> FunctionLiteral<'db> { + fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool { self.decorators(db).contains(decorator) } - /// Convert the `FunctionType` into a [`Type::Callable`]. - pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { - Type::Callable(CallableType::from_overloads( - db, - self.signature(db).iter().cloned(), - )) - } - - /// Returns the [`FileRange`] of the function's name. - pub fn focus_range(self, db: &dyn Db) -> FileRange { - FileRange::new( - self.body_scope(db).file(db), - self.body_scope(db).node(db).expect_function().name.range, - ) - } - - pub fn full_range(self, db: &dyn Db) -> FileRange { - FileRange::new( - self.body_scope(db).file(db), - self.body_scope(db).node(db).expect_function().range, - ) - } - - pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { - let body_scope = self.body_scope(db); - let index = semantic_index(db, body_scope.file(db)); - index.expect_single_definition(body_scope.node(db).expect_function()) - } - /// Typed externally-visible signature for this function. /// /// This is the signature as seen by external callers, possibly modified by decorators and/or @@ -5934,8 +5988,8 @@ impl<'db> FunctionType<'db> { /// /// Were this not a salsa query, then the calling query /// would depend on the function's AST and rerun for every change in that file. - #[salsa::tracked(return_ref)] - pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { + #[salsa::tracked] + fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { let mut internal_signature = self.internal_signature(db); if let Some(specialization) = self.specialization(db) { @@ -6007,16 +6061,16 @@ impl<'db> FunctionType<'db> { self.known(db) == Some(known_function) } - fn with_generic_context(self, db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self { - Self::new( + fn with_generic_context( + self, + db: &'db dyn Db, + generic_context: GenericContext<'db>, + ) -> FunctionType<'db> { + FunctionType::InheritedGenericContext(FunctionWithInheritedGenericContext::new( db, - self.name(db).clone(), - self.known(db), - self.body_scope(db), - self.decorators(db), - Some(generic_context), - self.specialization(db), - ) + self, + generic_context, + )) } fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { @@ -6036,6 +6090,12 @@ impl<'db> FunctionType<'db> { } } +#[salsa::interned(debug)] +pub struct FunctionWithInheritedGenericContext<'db> { + function: FunctionLiteral<'db>, + generic_context: GenericContext<'db>, +} + /// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might /// have special behavior. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, strum_macros::EnumString)] diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 7232a08f8d..f5c1f3f940 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -219,7 +219,8 @@ impl<'db> Bindings<'db> { match binding_type { Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { - if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { + let function_literal = function.function_literal(); + if function_literal.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { match overload.parameter_types() { [_, Some(owner)] => { overload.set_return_type(Type::BoundMethod(BoundMethodType::new( @@ -250,7 +251,9 @@ impl<'db> Bindings<'db> { if let [Some(function_ty @ Type::FunctionLiteral(function)), ..] = overload.parameter_types() { - if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { + let function_literal = function.function_literal(); + if function_literal.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) + { match overload.parameter_types() { [_, _, Some(owner)] => { overload.set_return_type(Type::BoundMethod( @@ -298,7 +301,7 @@ impl<'db> Bindings<'db> { if property.getter(db).is_some_and(|getter| { getter .into_function_literal() - .is_some_and(|f| f.name(db) == "__name__") + .is_some_and(|f| f.function_literal().name(db) == "__name__") }) => { overload.set_return_type(Type::string_literal(db, type_alias.name(db))); @@ -307,7 +310,7 @@ impl<'db> Bindings<'db> { if property.getter(db).is_some_and(|getter| { getter .into_function_literal() - .is_some_and(|f| f.name(db) == "__name__") + .is_some_and(|f| f.function_literal().name(db) == "__name__") }) => { overload.set_return_type(Type::string_literal(db, type_var.name(db))); @@ -416,7 +419,12 @@ impl<'db> Bindings<'db> { Type::BoundMethod(bound_method) if bound_method.self_instance(db).is_property_instance() => { - match bound_method.function(db).name(db).as_str() { + match bound_method + .function(db) + .function_literal() + .name(db) + .as_str() + { "setter" => { if let [Some(_), Some(setter)] = overload.parameter_types() { let mut ty_property = bound_method.self_instance(db); @@ -456,7 +464,10 @@ impl<'db> Bindings<'db> { } } - Type::FunctionLiteral(function_type) => match function_type.known(db) { + Type::FunctionLiteral(function_type) => match function_type + .function_literal() + .known(db) + { Some(KnownFunction::IsEquivalentTo) => { if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral( @@ -1166,7 +1177,7 @@ impl<'db> CallableDescription<'db> { match callable_type { Type::FunctionLiteral(function) => Some(CallableDescription { kind: "function", - name: function.name(db), + name: function.function_literal().name(db), }), Type::ClassLiteral(class_type) => Some(CallableDescription { kind: "class", @@ -1174,12 +1185,12 @@ impl<'db> CallableDescription<'db> { }), Type::BoundMethod(bound_method) => Some(CallableDescription { kind: "bound method", - name: bound_method.function(db).name(db), + name: bound_method.function(db).function_literal().name(db), }), Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { Some(CallableDescription { kind: "method wrapper `__get__` of function", - name: function.name(db), + name: function.function_literal().name(db), }) } Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(_)) => { @@ -1304,7 +1315,7 @@ impl<'db> BindingError<'db> { ) -> Option<(Span, Span)> { match callable_ty { Type::FunctionLiteral(function) => { - let function_scope = function.body_scope(db); + let function_scope = function.function_literal().body_scope(db); let span = Span::from(function_scope.file(db)); let node = function_scope.node(db); if let Some(func_def) = node.as_function() { diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 2f57261650..a7581a6ddb 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -610,7 +610,11 @@ impl<'db> ClassLiteralType<'db> { self.decorators(db) .iter() .filter_map(|deco| deco.into_function_literal()) - .any(|decorator| decorator.is_known(db, KnownFunction::Final)) + .any(|decorator| { + decorator + .function_literal() + .is_known(db, KnownFunction::Final) + }) } /// Attempt to resolve the [method resolution order] ("MRO") for this class. @@ -951,7 +955,9 @@ impl<'db> ClassLiteralType<'db> { Some(_), "__new__" | "__init__", ) => Type::FunctionLiteral( - function.with_generic_context(db, origin.generic_context(db)), + function + .function_literal() + .with_generic_context(db, origin.generic_context(db)), ), _ => ty, } diff --git a/crates/red_knot_python_semantic/src/types/context.rs b/crates/red_knot_python_semantic/src/types/context.rs index 5f92be04ab..f85bdc9563 100644 --- a/crates/red_knot_python_semantic/src/types/context.rs +++ b/crates/red_knot_python_semantic/src/types/context.rs @@ -169,7 +169,9 @@ impl<'db> InferContext<'db> { // Iterate over all functions and test if any is decorated with `@no_type_check`. function_scope_tys.any(|function_ty| { - function_ty.has_known_decorator(self.db, FunctionDecorators::NO_TYPE_CHECK) + function_ty + .function_literal() + .has_known_decorator(self.db, FunctionDecorators::NO_TYPE_CHECK) }) } InNoTypeCheck::Yes => true,