diff --git a/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md b/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md index d22660fe81..8630731764 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/generics/classes.md @@ -13,8 +13,6 @@ class C[T]: ... A class that inherits from a generic class, and fills its type parameters with typevars, is generic: ```py -# TODO: no error -# error: [non-subscriptable] class D[U](C[U]): ... ``` @@ -22,8 +20,6 @@ A class that inherits from a generic class, but fills its type parameters with c _not_ generic: ```py -# TODO: no error -# error: [non-subscriptable] class E(C[int]): ... ``` @@ -57,7 +53,7 @@ class D(C[T]): ... (Examples `E` and `F` from above do not have analogues in the legacy syntax.) -## Inferring generic class parameters +## Specializing generic classes explicitly The type parameter can be specified explicitly: @@ -65,25 +61,77 @@ The type parameter can be specified explicitly: class C[T]: x: T -# TODO: no error -# TODO: revealed: C[int] -# error: [non-subscriptable] -reveal_type(C[int]()) # revealed: C +reveal_type(C[int]()) # revealed: C[int] ``` +The specialization must match the generic types: + +```py +# error: [too-many-positional-arguments] "Too many positional arguments to class `C`: expected 1, got 2" +reveal_type(C[int, int]()) # revealed: Unknown +``` + +If the type variable has an upper bound, the specialized type must satisfy that bound: + +```py +class Bounded[T: int]: ... +class BoundedByUnion[T: int | str]: ... +class IntSubclass(int): ... + +reveal_type(Bounded[int]()) # revealed: Bounded[int] +reveal_type(Bounded[IntSubclass]()) # revealed: Bounded[IntSubclass] + +# error: [invalid-argument-type] "Object of type `str` cannot be assigned to parameter 1 (`T`) of class `Bounded`; expected type `int`" +reveal_type(Bounded[str]()) # revealed: Unknown + +# error: [invalid-argument-type] "Object of type `int | str` cannot be assigned to parameter 1 (`T`) of class `Bounded`; expected type `int`" +reveal_type(Bounded[int | str]()) # revealed: Unknown + +reveal_type(BoundedByUnion[int]()) # revealed: BoundedByUnion[int] +reveal_type(BoundedByUnion[IntSubclass]()) # revealed: BoundedByUnion[IntSubclass] +reveal_type(BoundedByUnion[str]()) # revealed: BoundedByUnion[str] +reveal_type(BoundedByUnion[int | str]()) # revealed: BoundedByUnion[int | str] +``` + +If the type variable is constrained, the specialized type must satisfy those constraints: + +```py +class Constrained[T: (int, str)]: ... + +reveal_type(Constrained[int]()) # revealed: Constrained[int] + +# TODO: error: [invalid-argument-type] +# TODO: revealed: Constrained[Unknown] +reveal_type(Constrained[IntSubclass]()) # revealed: Constrained[IntSubclass] + +reveal_type(Constrained[str]()) # revealed: Constrained[str] + +# TODO: error: [invalid-argument-type] +# TODO: revealed: Unknown +reveal_type(Constrained[int | str]()) # revealed: Constrained[int | str] + +# error: [invalid-argument-type] "Object of type `object` cannot be assigned to parameter 1 (`T`) of class `Constrained`; expected type `int | str`" +reveal_type(Constrained[object]()) # revealed: Unknown +``` + +## Inferring generic class parameters + We can infer the type parameter from a type context: ```py +class C[T]: + x: T + c: C[int] = C() # TODO: revealed: C[int] -reveal_type(c) # revealed: C +reveal_type(c) # revealed: C[Unknown] ``` The typevars of a fully specialized generic class should no longer be visible: ```py # TODO: revealed: int -reveal_type(c.x) # revealed: T +reveal_type(c.x) # revealed: Unknown ``` If the type parameter is not specified explicitly, and there are no constraints that let us infer a @@ -92,15 +140,13 @@ specific type, we infer the typevar's default type: ```py class D[T = int]: ... -# TODO: revealed: D[int] -reveal_type(D()) # revealed: D +reveal_type(D()) # revealed: D[int] ``` If a typevar does not provide a default, we use `Unknown`: ```py -# TODO: revealed: C[Unknown] -reveal_type(C()) # revealed: C +reveal_type(C()) # revealed: C[Unknown] ``` If the type of a constructor parameter is a class typevar, we can use that to infer the type @@ -111,17 +157,14 @@ class E[T]: def __init__(self, x: T) -> None: ... # TODO: revealed: E[int] or E[Literal[1]] -# TODO should not emit an error -# error: [invalid-argument-type] "Object of type `Literal[1]` cannot be assigned to parameter 2 (`x`) of bound method `__init__`; expected type `T`" -reveal_type(E(1)) # revealed: E +reveal_type(E(1)) # revealed: E[Unknown] ``` The types inferred from a type context and from a constructor parameter must be consistent with each other: ```py -# TODO: the error should not leak the `T` typevar and should mention `E[int]` -# error: [invalid-argument-type] "Object of type `Literal["five"]` cannot be assigned to parameter 2 (`x`) of bound method `__init__`; expected type `T`" +# TODO: error: [invalid-argument-type] wrong_innards: E[int] = E("five") ``` @@ -134,17 +177,33 @@ propagate through: class Base[T]: x: T | None = None -# TODO: no error -# error: [non-subscriptable] class Sub[U](Base[U]): ... +reveal_type(Base[int].x) # revealed: int | None +reveal_type(Sub[int].x) # revealed: int | None +``` + +## Generic methods + +Generic classes can contain methods that are themselves generic. The generic methods can refer to +the typevars of the enclosing generic class, and introduce new (distinct) typevars that are only in +scope for the method. + +```py +class C[T]: + def method[U](self, u: U) -> U: + return u + # error: [unresolved-reference] + def cannot_use_outside_of_method(self, u: U): ... + + # TODO: error + def cannot_shadow_class_typevar[T](self, t: T): ... + +c: C[int] = C[int]() # TODO: no error -# TODO: revealed: int | None -# error: [non-subscriptable] -reveal_type(Base[int].x) # revealed: T | None -# TODO: revealed: int | None -# error: [non-subscriptable] -reveal_type(Sub[int].x) # revealed: T | None +# TODO: revealed: str or Literal["string"] +# error: [invalid-argument-type] +reveal_type(c.method("string")) # revealed: U ``` ## Cyclic class definition @@ -158,8 +217,6 @@ Here, `Sub` is not a generic class, since it fills its superclass's type paramet ```pyi class Base[T]: ... -# TODO: no error -# error: [non-subscriptable] class Sub(Base[Sub]): ... reveal_type(Sub) # revealed: Literal[Sub] @@ -171,9 +228,6 @@ A similar case can work in a non-stub file, if forward references are stringifie ```py class Base[T]: ... - -# TODO: no error -# error: [non-subscriptable] class Sub(Base["Sub"]): ... reveal_type(Sub) # revealed: Literal[Sub] @@ -186,8 +240,6 @@ In a non-stub file, without stringified forward references, this raises a `NameE ```py class Base[T]: ... -# TODO: the unresolved-reference error is correct, the non-subscriptable is not -# error: [non-subscriptable] # error: [unresolved-reference] class Sub(Base[Sub]): ... ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md b/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md index 0d21a82294..2884b8d1b1 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md +++ b/crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md @@ -82,18 +82,51 @@ class C[T]: def m2(self, x: T) -> T: return x -c: C[int] = C() -# TODO: no error -# error: [invalid-argument-type] +c: C[int] = C[int]() c.m1(1) -# TODO: no error -# error: [invalid-argument-type] c.m2(1) -# TODO: expected type `int` -# error: [invalid-argument-type] "Object of type `Literal["string"]` cannot be assigned to parameter 2 (`x`) of bound method `m2`; expected type `T`" +# error: [invalid-argument-type] "Object of type `Literal["string"]` cannot be assigned to parameter 2 (`x`) of bound method `m2`; expected type `int`" c.m2("string") ``` +## Functions on generic classes are descriptors + +This repeats the tests in the [Functions as descriptors](./call/methods.md) test suite, but on a +generic class. This ensures that we are carrying any specializations through the entirety of the +descriptor protocol, which is how `self` parameters are bound to instance methods. + +```py +from inspect import getattr_static + +class C[T]: + def f(self, x: T) -> str: + return "a" + +reveal_type(getattr_static(C[int], "f")) # revealed: Literal[f[int]] +reveal_type(getattr_static(C[int], "f").__get__) # revealed: +reveal_type(getattr_static(C[int], "f").__get__(None, C[int])) # revealed: Literal[f[int]] +# revealed: +reveal_type(getattr_static(C[int], "f").__get__(C[int](), C[int])) + +reveal_type(C[int].f) # revealed: Literal[f[int]] +reveal_type(C[int]().f) # revealed: + +bound_method = C[int]().f +reveal_type(bound_method.__self__) # revealed: C[int] +reveal_type(bound_method.__func__) # revealed: Literal[f[int]] + +reveal_type(C[int]().f(1)) # revealed: str +reveal_type(bound_method(1)) # revealed: str + +C[int].f(1) # error: [missing-argument] +reveal_type(C[int].f(C[int](), 1)) # revealed: str + +class D[U](C[U]): + pass + +reveal_type(D[int]().f) # revealed: +``` + ## Methods can mention other typevars > A type variable used in a method that does not match any of the variables that parameterize the @@ -127,7 +160,6 @@ c: C[int] = C() # TODO: no errors # TODO: revealed: str # error: [invalid-argument-type] -# error: [invalid-argument-type] reveal_type(c.m(1, "string")) # revealed: S ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md index 0b7f0f49b6..e77a6152e7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md @@ -109,6 +109,24 @@ def _(x: A | B): reveal_type(x) # revealed: A ``` +## Narrowing for generic classes + +Note that `type` returns the runtime class of an object, which does _not_ include specializations in +the case of a generic class. (The typevars are erased.) That means we cannot narrow the type to the +specialization that we compare with; we must narrow to an unknown specialization of the generic +class. + +```py +class A[T = int]: ... +class B: ... + +def _[T](x: A | B): + if type(x) is A[str]: + reveal_type(x) # revealed: A[int] & A[Unknown] | B & A[Unknown] + else: + reveal_type(x) # revealed: A[int] | B +``` + ## Limitations ```py diff --git a/crates/red_knot_python_semantic/resources/mdtest/stubs/class.md b/crates/red_knot_python_semantic/resources/mdtest/stubs/class.md index e5d4956db9..233fa3f6f2 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/stubs/class.md +++ b/crates/red_knot_python_semantic/resources/mdtest/stubs/class.md @@ -8,13 +8,10 @@ In type stubs, classes can reference themselves in their base class definitions. ```pyi class Foo[T]: ... -# TODO: actually is subscriptable -# error: [non-subscriptable] class Bar(Foo[Bar]): ... reveal_type(Bar) # revealed: Literal[Bar] -# TODO: Instead of `Literal[Foo]`, we might eventually want to show a type that involves the type parameter. -reveal_type(Bar.__mro__) # revealed: tuple[Literal[Bar], Literal[Foo], Literal[object]] +reveal_type(Bar.__mro__) # revealed: tuple[Literal[Bar], Literal[Foo[Bar]], Literal[object]] ``` ## Access to attributes declared in stubs diff --git a/crates/red_knot_python_semantic/src/symbol.rs b/crates/red_knot_python_semantic/src/symbol.rs index 24f2f5d77b..437f2dd3c2 100644 --- a/crates/red_knot_python_semantic/src/symbol.rs +++ b/crates/red_knot_python_semantic/src/symbol.rs @@ -425,6 +425,17 @@ impl<'db> SymbolAndQualifiers<'db> { self.qualifiers.contains(TypeQualifiers::CLASS_VAR) } + #[must_use] + pub(crate) fn map_type( + self, + f: impl FnOnce(Type<'db>) -> Type<'db>, + ) -> SymbolAndQualifiers<'db> { + SymbolAndQualifiers { + symbol: self.symbol.map_type(f), + qualifiers: self.qualifiers, + } + } + /// Transform symbol and qualifiers into a [`LookupResult`], /// a [`Result`] type in which the `Ok` variant represents a definitely bound symbol /// and the `Err` variant represents a symbol that is either definitely or possibly unbound. diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index e04b4bf744..3197f7fa9a 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -35,12 +35,16 @@ use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers}; use crate::types::call::{Bindings, CallArgumentTypes}; pub(crate) use crate::types::class_base::ClassBase; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; +use crate::types::generics::Specialization; use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::{Db, FxOrderSet, Module, Program}; -pub(crate) use class::{Class, ClassLiteralType, InstanceType, KnownClass, KnownInstanceType}; +pub(crate) use class::{ + Class, ClassLiteralType, ClassType, GenericAlias, GenericClass, InstanceType, KnownClass, + KnownInstanceType, NonGenericClass, +}; mod builder; mod call; @@ -49,6 +53,7 @@ mod class_base; mod context; mod diagnostic; mod display; +mod generics; mod infer; mod mro; mod narrow; @@ -276,6 +281,18 @@ pub struct PropertyInstanceType<'db> { setter: Option>, } +impl<'db> PropertyInstanceType<'db> { + fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { + let getter = self + .getter(db) + .map(|ty| ty.apply_specialization(db, specialization)); + let setter = self + .setter(db) + .map(|ty| ty.apply_specialization(db, specialization)); + Self::new(db, getter, setter) + } +} + /// Representation of a type: a set of possible values at runtime. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] pub enum Type<'db> { @@ -319,7 +336,9 @@ pub enum Type<'db> { ModuleLiteral(ModuleLiteralType<'db>), /// A specific class object ClassLiteral(ClassLiteralType<'db>), - // The set of all class objects that are subclasses of the given class (C), spelled `type[C]`. + /// A specialization of a generic class + GenericAlias(GenericAlias<'db>), + /// The set of all class objects that are subclasses of the given class (C), spelled `type[C]`. SubclassOf(SubclassOfType<'db>), /// The set of Python objects with the given class in their __class__'s method resolution order Instance(InstanceType<'db>), @@ -382,20 +401,17 @@ impl<'db> Type<'db> { fn is_none(&self, db: &'db dyn Db) -> bool { self.into_instance() - .is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType)) + .is_some_and(|instance| instance.class.is_known(db, KnownClass::NoneType)) } pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool { - self.into_instance().is_some_and(|instance| { - instance - .class() - .is_known(db, KnownClass::NotImplementedType) - }) + self.into_instance() + .is_some_and(|instance| instance.class.is_known(db, KnownClass::NotImplementedType)) } pub fn is_object(&self, db: &'db dyn Db) -> bool { self.into_instance() - .is_some_and(|instance| instance.class().is_object(db)) + .is_some_and(|instance| instance.class.is_object(db)) } pub const fn is_todo(&self) -> bool { @@ -426,6 +442,12 @@ impl<'db> Type<'db> { | Self::WrapperDescriptor(_) | Self::MethodWrapper(_) => false, + Self::GenericAlias(generic) => generic + .specialization(db) + .types(db) + .iter() + .any(|ty| ty.contains_todo(db)), + Self::Callable(callable) => { let signature = callable.signature(db); signature.parameters().iter().any(|param| { @@ -467,10 +489,6 @@ impl<'db> Type<'db> { } } - pub const fn class_literal(class: Class<'db>) -> Self { - Self::ClassLiteral(ClassLiteralType { class }) - } - pub const fn into_class_literal(self) -> Option> { match self { Type::ClassLiteral(class_type) => Some(class_type), @@ -492,6 +510,29 @@ impl<'db> Type<'db> { matches!(self, Type::ClassLiteral(..)) } + pub const fn into_class_type(self) -> Option> { + match self { + Type::ClassLiteral(ClassLiteralType::NonGeneric(non_generic)) => { + Some(ClassType::NonGeneric(non_generic)) + } + Type::GenericAlias(alias) => Some(ClassType::Generic(alias)), + _ => None, + } + } + + #[track_caller] + pub fn expect_class_type(self) -> ClassType<'db> { + self.into_class_type() + .expect("Expected a Type::GenericAlias or non-generic Type::ClassLiteral variant") + } + + pub const fn is_class_type(&self) -> bool { + matches!( + self, + Type::ClassLiteral(ClassLiteralType::NonGeneric(_)) | Type::GenericAlias(_) + ) + } + pub const fn is_instance(&self) -> bool { matches!(self, Type::Instance(..)) } @@ -631,7 +672,7 @@ impl<'db> Type<'db> { matches!(self, Type::LiteralString) } - pub const fn instance(class: Class<'db>) -> Self { + pub const fn instance(class: ClassType<'db>) -> Self { Self::Instance(InstanceType { class }) } @@ -693,6 +734,10 @@ impl<'db> Type<'db> { | Type::KnownInstance(_) | Type::IntLiteral(_) | Type::SubclassOf(_) => self, + Type::GenericAlias(generic) => { + let specialization = generic.specialization(db).normalized(db); + Type::GenericAlias(GenericAlias::new(db, generic.origin(db), specialization)) + } Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) { Some(TypeVarBoundOrConstraints::UpperBound(bound)) => { Type::TypeVar(TypeVarInstance::new( @@ -932,13 +977,16 @@ impl<'db> Type<'db> { // `Literal[]` is a subtype of `type[B]` if `C` is a subclass of `B`, // since `type[B]` describes all possible runtime subclasses of the class object `B`. - ( - Type::ClassLiteral(ClassLiteralType { class }), - Type::SubclassOf(target_subclass_ty), - ) => target_subclass_ty + (Type::ClassLiteral(class), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty .subclass_of() .into_class() - .is_some_and(|target_class| class.is_subclass_of(db, target_class)), + .is_some_and(|target_class| class.is_subclass_of(db, None, target_class)), + (Type::GenericAlias(alias), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty + .subclass_of() + .into_class() + .is_some_and(|target_class| { + ClassType::from(alias).is_subclass_of(db, target_class) + }), // This branch asks: given two types `type[T]` and `type[S]`, is `type[T]` a subtype of `type[S]`? (Type::SubclassOf(self_subclass_ty), Type::SubclassOf(target_subclass_ty)) => { @@ -948,9 +996,12 @@ impl<'db> Type<'db> { // `Literal[str]` is a subtype of `type` because the `str` class object is an instance of its metaclass `type`. // `Literal[abc.ABC]` is a subtype of `abc.ABCMeta` because the `abc.ABC` class object // is an instance of its metaclass `abc.ABCMeta`. - (Type::ClassLiteral(ClassLiteralType { class }), _) => { + (Type::ClassLiteral(class), _) => { class.metaclass_instance_type(db).is_subtype_of(db, target) } + (Type::GenericAlias(alias), _) => ClassType::from(alias) + .metaclass_instance_type(db) + .is_subtype_of(db, target), // `type[str]` (== `SubclassOf("str")` in red-knot) describes all possible runtime subclasses // of the class object `str`. It is a subtype of `type` (== `Instance("type")`) because `str` @@ -1141,11 +1192,10 @@ impl<'db> Type<'db> { // Every class literal type is also assignable to `type[Any]`, because the class // literal type for a class `C` is a subtype of `type[C]`, and `type[C]` is assignable // to `type[Any]`. - (Type::ClassLiteral(_) | Type::SubclassOf(_), Type::SubclassOf(target_subclass_of)) - if target_subclass_of.is_dynamic() => - { - true - } + ( + Type::ClassLiteral(_) | Type::GenericAlias(_) | Type::SubclassOf(_), + Type::SubclassOf(target_subclass_of), + ) if target_subclass_of.is_dynamic() => true, // `type[Any]` is assignable to any type that `type[object]` is assignable to, because // `type[Any]` can materialize to `type[object]`. @@ -1386,6 +1436,7 @@ impl<'db> Type<'db> { | Type::WrapperDescriptor(..) | Type::ModuleLiteral(..) | Type::ClassLiteral(..) + | Type::GenericAlias(..) | Type::KnownInstance(..)), right @ (Type::BooleanLiteral(..) | Type::IntLiteral(..) @@ -1398,6 +1449,7 @@ impl<'db> Type<'db> { | Type::WrapperDescriptor(..) | Type::ModuleLiteral(..) | Type::ClassLiteral(..) + | Type::GenericAlias(..) | Type::KnownInstance(..)), ) => left != right, @@ -1406,6 +1458,7 @@ impl<'db> Type<'db> { ( Type::Tuple(..), Type::ClassLiteral(..) + | Type::GenericAlias(..) | Type::ModuleLiteral(..) | Type::BooleanLiteral(..) | Type::BytesLiteral(..) @@ -1420,6 +1473,7 @@ impl<'db> Type<'db> { ) | ( Type::ClassLiteral(..) + | Type::GenericAlias(..) | Type::ModuleLiteral(..) | Type::BooleanLiteral(..) | Type::BytesLiteral(..) @@ -1434,17 +1488,23 @@ impl<'db> Type<'db> { Type::Tuple(..), ) => true, - ( - Type::SubclassOf(subclass_of_ty), - Type::ClassLiteral(ClassLiteralType { class: class_b }), - ) - | ( - Type::ClassLiteral(ClassLiteralType { class: class_b }), - Type::SubclassOf(subclass_of_ty), - ) => match subclass_of_ty.subclass_of() { - ClassBase::Dynamic(_) => false, - ClassBase::Class(class_a) => !class_b.is_subclass_of(db, class_a), - }, + (Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b)) + | (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => { + match subclass_of_ty.subclass_of() { + ClassBase::Dynamic(_) => false, + ClassBase::Class(class_a) => !class_b.is_subclass_of(db, None, class_a), + } + } + + (Type::SubclassOf(subclass_of_ty), Type::GenericAlias(alias_b)) + | (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => { + match subclass_of_ty.subclass_of() { + ClassBase::Dynamic(_) => false, + ClassBase::Class(class_a) => { + !ClassType::from(alias_b).is_subclass_of(db, class_a) + } + } + } ( Type::SubclassOf(_), @@ -1561,12 +1621,14 @@ impl<'db> Type<'db> { // A class-literal type `X` is always disjoint from an instance type `Y`, // unless the type expressing "all instances of `Z`" is a subtype of of `Y`, // where `Z` is `X`'s metaclass. - (Type::ClassLiteral(ClassLiteralType { class }), instance @ Type::Instance(_)) - | (instance @ Type::Instance(_), Type::ClassLiteral(ClassLiteralType { class })) => { - !class - .metaclass_instance_type(db) - .is_subtype_of(db, instance) - } + (Type::ClassLiteral(class), instance @ Type::Instance(_)) + | (instance @ Type::Instance(_), Type::ClassLiteral(class)) => !class + .metaclass_instance_type(db) + .is_subtype_of(db, instance), + (Type::GenericAlias(alias), instance @ Type::Instance(_)) + | (instance @ Type::Instance(_), Type::GenericAlias(alias)) => !ClassType::from(alias) + .metaclass_instance_type(db) + .is_subtype_of(db, instance), (Type::FunctionLiteral(..), Type::Instance(InstanceType { class })) | (Type::Instance(InstanceType { class }), Type::FunctionLiteral(..)) => { @@ -1692,7 +1754,7 @@ impl<'db> Type<'db> { }, Type::SubclassOf(subclass_of_ty) => subclass_of_ty.is_fully_static(), - Type::ClassLiteral(_) | Type::Instance(_) => { + Type::ClassLiteral(_) | Type::GenericAlias(_) | Type::Instance(_) => { // TODO: Ideally, we would iterate over the MRO of the class, check if all // bases are fully static, and only return `true` if that is the case. // @@ -1763,6 +1825,7 @@ impl<'db> Type<'db> { | Type::FunctionLiteral(..) | Type::WrapperDescriptor(..) | Type::ClassLiteral(..) + | Type::GenericAlias(..) | Type::ModuleLiteral(..) | Type::KnownInstance(..) => true, Type::Callable(_) => { @@ -1828,6 +1891,7 @@ impl<'db> Type<'db> { | Type::MethodWrapper(_) | Type::ModuleLiteral(..) | Type::ClassLiteral(..) + | Type::GenericAlias(..) | Type::IntLiteral(..) | Type::BooleanLiteral(..) | Type::StringLiteral(..) @@ -1912,7 +1976,7 @@ impl<'db> Type<'db> { Type::Dynamic(_) | Type::Never => Some(Symbol::bound(self).into()), - Type::ClassLiteral(class_literal @ ClassLiteralType { class }) => { + Type::ClassLiteral(class) => { match (class.known(db), name) { (Some(KnownClass::FunctionType), "__get__") => Some( Symbol::bound(Type::WrapperDescriptor( @@ -1956,10 +2020,14 @@ impl<'db> Type<'db> { "__get__" | "__set__" | "__delete__", ) => Some(Symbol::Unbound.into()), - _ => Some(class_literal.class_member(db, name, policy)), + _ => Some(class.class_member(db, None, name, policy)), } } + Type::GenericAlias(alias) => { + Some(ClassType::from(*alias).class_member(db, name, policy)) + } + Type::SubclassOf(subclass_of) if name == "__get__" && matches!( @@ -2126,7 +2194,9 @@ impl<'db> Type<'db> { // a `__dict__` that is filled with class level attributes. Modeling this is currently not // required, as `instance_member` is only called for instance-like types through `member`, // but we might want to add this in the future. - Type::ClassLiteral(_) | Type::SubclassOf(_) => Symbol::Unbound.into(), + Type::ClassLiteral(_) | Type::GenericAlias(_) | Type::SubclassOf(_) => { + Symbol::Unbound.into() + } } } @@ -2439,7 +2509,7 @@ impl<'db> Type<'db> { ) .into(), - Type::ClassLiteral(ClassLiteralType { class }) + Type::ClassLiteral(class) if name == "__get__" && class.is_known(db, KnownClass::FunctionType) => { Symbol::bound(Type::WrapperDescriptor( @@ -2447,7 +2517,7 @@ impl<'db> Type<'db> { )) .into() } - Type::ClassLiteral(ClassLiteralType { class }) + Type::ClassLiteral(class) if name == "__get__" && class.is_known(db, KnownClass::Property) => { Symbol::bound(Type::WrapperDescriptor( @@ -2455,7 +2525,7 @@ impl<'db> Type<'db> { )) .into() } - Type::ClassLiteral(ClassLiteralType { class }) + Type::ClassLiteral(class) if name == "__set__" && class.is_known(db, KnownClass::Property) => { Symbol::bound(Type::WrapperDescriptor( @@ -2599,8 +2669,8 @@ impl<'db> Type<'db> { } } - Type::ClassLiteral(..) | Type::SubclassOf(..) => { - let class_attr_plain = self.find_name_in_mro_with_policy(db, name_str, policy).expect( + Type::ClassLiteral(..) | Type::GenericAlias(..) | Type::SubclassOf(..) => { + let class_attr_plain = self.find_name_in_mro_with_policy(db, name_str,policy).expect( "Calling `find_name_in_mro` on class literals and subclass-of types should always return `Some`", ); @@ -2785,14 +2855,17 @@ impl<'db> Type<'db> { Type::AlwaysFalsy => Truthiness::AlwaysFalse, - Type::ClassLiteral(ClassLiteralType { class }) => class + Type::ClassLiteral(class) => class + .metaclass_instance_type(db) + .try_bool_impl(db, allow_short_circuit)?, + Type::GenericAlias(alias) => ClassType::from(*alias) .metaclass_instance_type(db) .try_bool_impl(db, allow_short_circuit)?, Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() { ClassBase::Dynamic(_) => Truthiness::Ambiguous, ClassBase::Class(class) => { - Type::class_literal(class).try_bool_impl(db, allow_short_circuit)? + Type::from(class).try_bool_impl(db, allow_short_circuit)? } }, @@ -3141,7 +3214,7 @@ impl<'db> Type<'db> { )), }, - Type::ClassLiteral(ClassLiteralType { class }) => match class.known(db) { + Type::ClassLiteral(class) => match class.known(db) { // TODO: Ideally we'd use `try_call_constructor` for all constructor calls. // Currently we don't for a few special known types, either because their // constructors are defined with overloads, or because we want to special case @@ -3345,13 +3418,23 @@ impl<'db> Type<'db> { } }, + Type::GenericAlias(_) => { + // TODO annotated return type on `__new__` or metaclass `__call__` + // TODO check call vs signatures of `__new__` and/or `__init__` + let signature = CallableSignature::single( + self, + Signature::new(Parameters::gradual_form(), self.to_instance(db)), + ); + Signatures::single(signature) + } + Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() { ClassBase::Dynamic(dynamic_type) => Type::Dynamic(dynamic_type).signatures(db), // Most type[] constructor calls are handled by `try_call_constructor` and not via // getting the signature here. This signature can still be used in some cases (e.g. // evaluating callable subtyping). TODO improve this definition (intersection of // `__new__` and `__init__` signatures? and respect metaclass `__call__`). - ClassBase::Class(class) => Type::class_literal(class).signatures(db), + ClassBase::Class(class) => Type::from(class).signatures(db), }, Type::Instance(_) => { @@ -3729,7 +3812,8 @@ impl<'db> Type<'db> { pub fn to_instance(&self, db: &'db dyn Db) -> Option> { match self { Type::Dynamic(_) | Type::Never => Some(*self), - Type::ClassLiteral(ClassLiteralType { class }) => Some(Type::instance(*class)), + Type::ClassLiteral(class) => Some(Type::instance(class.default_specialization(db))), + Type::GenericAlias(alias) => Some(Type::instance(ClassType::from(*alias))), Type::SubclassOf(subclass_of_ty) => Some(subclass_of_ty.to_instance()), Type::Union(union) => { let mut builder = UnionBuilder::new(db); @@ -3774,7 +3858,7 @@ impl<'db> Type<'db> { match self { // Special cases for `float` and `complex` // https://typing.readthedocs.io/en/latest/spec/special-types.html#special-cases-for-float-and-complex - Type::ClassLiteral(ClassLiteralType { class }) => { + Type::ClassLiteral(class) => { let ty = match class.known(db) { Some(KnownClass::Any) => Type::any(), Some(KnownClass::Complex) => UnionType::from_elements( @@ -3792,10 +3876,11 @@ impl<'db> Type<'db> { KnownClass::Float.to_instance(db), ], ), - _ => Type::instance(*class), + _ => Type::instance(class.default_specialization(db)), }; Ok(ty) } + Type::GenericAlias(alias) => Ok(Type::instance(ClassType::from(*alias))), Type::SubclassOf(_) | Type::BooleanLiteral(_) @@ -4033,7 +4118,8 @@ impl<'db> Type<'db> { } }, - Type::ClassLiteral(ClassLiteralType { class }) => class.metaclass(db), + Type::ClassLiteral(class) => class.metaclass(db), + Type::GenericAlias(alias) => ClassType::from(*alias).metaclass(db), Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() { ClassBase::Dynamic(_) => *self, ClassBase::Class(class) => SubclassOfType::from( @@ -4055,6 +4141,121 @@ impl<'db> Type<'db> { } } + /// Applies a specialization to this type, replacing any typevars with the types that they are + /// specialized to. + /// + /// Note that this does not specialize generic classes, functions, or type aliases! That is a + /// different operation that is performed explicitly (via a subscript operation), or implicitly + /// via a call to the generic object. + #[must_use] + #[salsa::tracked] + pub fn apply_specialization( + self, + db: &'db dyn Db, + specialization: Specialization<'db>, + ) -> Type<'db> { + match self { + Type::TypeVar(typevar) => specialization.get(db, typevar).unwrap_or(self), + + Type::FunctionLiteral(function) => { + Type::FunctionLiteral(function.apply_specialization(db, specialization)) + } + + // Note that we don't need to apply the specialization to `self_instance`, since it + // must either be a non-generic class literal (which cannot have any typevars to + // specialize) or a generic alias (which has already been fully specialized). For a + // generic alias, the specialization being applied here must be for some _other_ + // generic context nested within the generic alias's class literal, which the generic + // alias's context cannot refer to. (The _method_ does need to be specialized, since it + // might be a nested generic method, whose generic context is what is now being + // specialized.) + Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new( + db, + method.function(db).apply_specialization(db, specialization), + method.self_instance(db), + )), + + Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { + Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet( + function.apply_specialization(db, specialization), + )) + } + + Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(property)) => { + Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet( + property.apply_specialization(db, specialization), + )) + } + + Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(property)) => { + Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet( + property.apply_specialization(db, specialization), + )) + } + + Type::Callable(callable) => { + Type::Callable(callable.apply_specialization(db, specialization)) + } + + Type::GenericAlias(generic) => { + let specialization = generic + .specialization(db) + .apply_specialization(db, specialization); + Type::GenericAlias(GenericAlias::new(db, generic.origin(db), specialization)) + } + + Type::PropertyInstance(property) => { + Type::PropertyInstance(property.apply_specialization(db, specialization)) + } + + Type::Union(union) => union.map(db, |element| { + element.apply_specialization(db, specialization) + }), + Type::Intersection(intersection) => { + let mut builder = IntersectionBuilder::new(db); + for positive in intersection.positive(db) { + builder = + builder.add_positive(positive.apply_specialization(db, specialization)); + } + for negative in intersection.negative(db) { + builder = + builder.add_negative(negative.apply_specialization(db, specialization)); + } + builder.build() + } + Type::Tuple(tuple) => TupleType::from_elements( + db, + tuple + .iter(db) + .map(|ty| ty.apply_specialization(db, specialization)), + ), + + Type::Dynamic(_) + | Type::Never + | Type::AlwaysTruthy + | Type::AlwaysFalsy + | Type::WrapperDescriptor(_) + | Type::ModuleLiteral(_) + // A non-generic class never needs to be specialized. A generic class is specialized + // explicitly (via a subscript expression) or implicitly (via a call), and not because + // some other generic context's specialization is applied to it. + | Type::ClassLiteral(_) + // SubclassOf contains a ClassType, which has already been specialized if needed, like + // above with BoundMethod's self_instance. + | Type::SubclassOf(_) + | Type::IntLiteral(_) + | Type::BooleanLiteral(_) + | Type::LiteralString + | Type::StringLiteral(_) + | Type::BytesLiteral(_) + | Type::SliceLiteral(_) + // Instance contains a ClassType, which has already been specialized if needed, like + // above with BoundMethod's self_instance. + | Type::Instance(_) + | Type::KnownInstance(_) => self, + } + } + /// Return the string representation of this type when converted to string as it would be /// provided by the `__str__` method. /// @@ -4111,11 +4312,10 @@ impl<'db> Type<'db> { } Self::ModuleLiteral(module) => Some(TypeDefinition::Module(module.module(db))), Self::ClassLiteral(class_literal) => { - Some(TypeDefinition::Class(class_literal.class().definition(db))) - } - Self::Instance(instance) => { - Some(TypeDefinition::Class(instance.class().definition(db))) + Some(TypeDefinition::Class(class_literal.definition(db))) } + Self::GenericAlias(alias) => Some(TypeDefinition::Class(alias.definition(db))), + Self::Instance(instance) => Some(TypeDefinition::Class(instance.class.definition(db))), Self::KnownInstance(instance) => match instance { KnownInstanceType::TypeVar(var) => { Some(TypeDefinition::TypeVar(var.definition(db))) @@ -5133,6 +5333,11 @@ pub struct FunctionType<'db> { /// A set of special decorators that were applied to this function decorators: FunctionDecorators, + + /// A specialization that should be applied to the function's parameter and return types, + /// either because the function is itself generic, or because it appears in the body of a + /// generic class. + specialization: Option>, } #[salsa::tracked] @@ -5183,13 +5388,17 @@ impl<'db> FunctionType<'db> { /// would depend on the function's AST and rerun for every change in that file. #[salsa::tracked(return_ref)] pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { - let internal_signature = self.internal_signature(db); + let mut internal_signature = self.internal_signature(db); if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) { - Signature::todo("return type of overloaded function") - } else { - internal_signature + return Signature::todo("return type of overloaded function"); } + + if let Some(specialization) = self.specialization(db) { + internal_signature.apply_specialization(db, specialization); + } + + internal_signature } /// Typed internally-visible signature for this function. @@ -5212,6 +5421,21 @@ impl<'db> FunctionType<'db> { pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { self.known(db) == Some(known_function) } + + fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { + let specialization = match self.specialization(db) { + Some(existing) => existing.apply_specialization(db, specialization), + None => specialization, + }; + Self::new( + db, + self.name(db).clone(), + self.known(db), + self.body_scope(db), + self.decorators(db), + Some(specialization), + ) + } } /// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might @@ -5382,6 +5606,12 @@ impl<'db> CallableType<'db> { CallableType::new(db, Signature::new(parameters, return_ty)) } + fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { + let mut signature = self.signature(db).clone(); + signature.apply_specialization(db, specialization); + Self::new(db, signature) + } + /// Returns `true` if this is a fully static callable type. /// /// A callable type is fully static if all of its parameters and return type are fully static @@ -6006,8 +6236,8 @@ impl<'db> TypeAliasType<'db> { /// Either the explicit `metaclass=` keyword of the class, or the inferred metaclass of one of its base classes. #[derive(Debug, Clone, PartialEq, Eq, salsa::Update)] pub(super) struct MetaclassCandidate<'db> { - metaclass: Class<'db>, - explicit_metaclass_of: Class<'db>, + metaclass: ClassType<'db>, + explicit_metaclass_of: ClassLiteralType<'db>, } #[salsa::interned(debug)] @@ -6579,6 +6809,10 @@ impl<'db> TupleType<'db> { pub fn len(&self, db: &'db dyn Db) -> usize { self.elements(db).len() } + + pub fn iter(&self, db: &'db dyn Db) -> impl Iterator> + 'db + '_ { + self.elements(db).iter().copied() + } } // Make sure that the `Type` enum does not grow unexpectedly. diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 8e91d024b9..c932474703 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -292,7 +292,7 @@ impl<'db> InnerIntersectionBuilder<'db> { _ => { let known_instance = new_positive .into_instance() - .and_then(|instance| instance.class().known(db)); + .and_then(|instance| instance.class.known(db)); if known_instance == Some(KnownClass::Object) { // `object & T` -> `T`; it is always redundant to add `object` to an intersection @@ -312,7 +312,7 @@ impl<'db> InnerIntersectionBuilder<'db> { new_positive = Type::BooleanLiteral(false); } Type::Instance(instance) - if instance.class().is_known(db, KnownClass::Bool) => + if instance.class.is_known(db, KnownClass::Bool) => { match new_positive { // `bool & AlwaysTruthy` -> `Literal[True]` @@ -406,7 +406,7 @@ impl<'db> InnerIntersectionBuilder<'db> { self.positive .iter() .filter_map(|ty| ty.into_instance()) - .filter_map(|instance| instance.class().known(db)) + .filter_map(|instance| instance.class.known(db)) .any(KnownClass::is_bool) }; @@ -422,7 +422,7 @@ impl<'db> InnerIntersectionBuilder<'db> { Type::Never => { // Adding ~Never to an intersection is a no-op. } - Type::Instance(instance) if instance.class().is_object(db) => { + Type::Instance(instance) if instance.class.is_object(db) => { // Adding ~object to an intersection results in Never. *self = Self::default(); self.positive.insert(Type::Never); diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index a2f25b1289..3bafa82cd1 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -18,8 +18,8 @@ use crate::types::diagnostic::{ }; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ - todo_type, BoundMethodType, ClassLiteralType, FunctionDecorators, KnownClass, KnownFunction, - KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind, + todo_type, BoundMethodType, FunctionDecorators, KnownClass, KnownFunction, KnownInstanceType, + MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind, }; use ruff_db::diagnostic::{OldSecondaryDiagnosticMessage, Span}; use ruff_python_ast as ast; @@ -566,7 +566,7 @@ impl<'db> Bindings<'db> { _ => {} }, - Type::ClassLiteral(ClassLiteralType { class }) => match class.known(db) { + Type::ClassLiteral(class) => match class.known(db) { Some(KnownClass::Bool) => match overload.parameter_types() { [Some(arg)] => overload.set_return_type(arg.bool(db).into_type(db)), [None] => overload.set_return_type(Type::BooleanLiteral(false)), @@ -1064,7 +1064,7 @@ impl<'db> CallableDescription<'db> { }), Type::ClassLiteral(class_type) => Some(CallableDescription { kind: "class", - name: class_type.class().name(db), + name: class_type.name(db), }), Type::BoundMethod(bound_method) => Some(CallableDescription { kind: "bound method", diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index b73a62c792..a3fe1024b4 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -6,6 +6,7 @@ use super::{ Type, TypeAliasType, TypeQualifiers, TypeVarInstance, }; use crate::semantic_index::definition::Definition; +use crate::types::generics::{GenericContext, Specialization}; use crate::{ module_resolver::file_to_module, semantic_index::{ @@ -24,36 +25,199 @@ use crate::{ }; use indexmap::IndexSet; use itertools::Itertools as _; -use ruff_db::files::{File, FileRange}; +use ruff_db::files::File; use ruff_python_ast::{self as ast, PythonVersion}; use rustc_hash::FxHashSet; -/// Representation of a runtime class object. -/// -/// Does not in itself represent a type, -/// but is used as the inner data for several structs that *do* represent types. -#[salsa::interned(debug)] +fn explicit_bases_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &[Type<'db>], + _count: u32, + _self: ClassLiteralType<'db>, +) -> salsa::CycleRecoveryAction]>> { + salsa::CycleRecoveryAction::Iterate +} + +fn explicit_bases_cycle_initial<'db>( + _db: &'db dyn Db, + _self: ClassLiteralType<'db>, +) -> Box<[Type<'db>]> { + Box::default() +} + +fn try_mro_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Result, MroError<'db>>, + _count: u32, + _self: ClassLiteralType<'db>, + _specialization: Option>, +) -> salsa::CycleRecoveryAction, MroError<'db>>> { + salsa::CycleRecoveryAction::Iterate +} + +#[allow(clippy::unnecessary_wraps)] +fn try_mro_cycle_initial<'db>( + db: &'db dyn Db, + self_: ClassLiteralType<'db>, + specialization: Option>, +) -> Result, MroError<'db>> { + Ok(Mro::from_error( + db, + self_.apply_optional_specialization(db, specialization), + )) +} + +#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)] +fn inheritance_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Option, + _count: u32, + _self: ClassLiteralType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn inheritance_cycle_initial<'db>( + _db: &'db dyn Db, + _self: ClassLiteralType<'db>, +) -> Option { + None +} + +/// Representation of a class definition statement in the AST. This does not in itself represent a +/// type, but is used as the inner data for several structs that *do* represent types. +#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)] pub struct Class<'db> { /// Name of the class at definition - #[return_ref] pub(crate) name: ast::name::Name, - body_scope: ScopeId<'db>, + pub(crate) body_scope: ScopeId<'db>, pub(crate) known: Option, } -#[salsa::tracked] impl<'db> Class<'db> { + fn file(&self, db: &dyn Db) -> File { + self.body_scope.file(db) + } + + /// Return the original [`ast::StmtClassDef`] node associated with this class + /// + /// ## Note + /// Only call this function from queries in the same file or your + /// query depends on the AST of another file (bad!). + fn node(&self, db: &'db dyn Db) -> &'db ast::StmtClassDef { + self.body_scope.node(db).expect_class() + } + + fn definition(&self, db: &'db dyn Db) -> Definition<'db> { + let index = semantic_index(db, self.body_scope.file(db)); + index.expect_single_definition(self.body_scope.node(db).expect_class()) + } +} + +/// A [`Class`] that is not generic. +#[salsa::interned(debug)] +pub struct NonGenericClass<'db> { + #[return_ref] + pub(crate) class: Class<'db>, +} + +impl<'db> From> for Type<'db> { + fn from(class: NonGenericClass<'db>) -> Type<'db> { + Type::ClassLiteral(ClassLiteralType::NonGeneric(class)) + } +} + +/// A [`Class`] that is generic. +#[salsa::interned(debug)] +pub struct GenericClass<'db> { + #[return_ref] + pub(crate) class: Class<'db>, + pub(crate) generic_context: GenericContext<'db>, +} + +impl<'db> From> for Type<'db> { + fn from(class: GenericClass<'db>) -> Type<'db> { + Type::ClassLiteral(ClassLiteralType::Generic(class)) + } +} + +/// A specialization of a generic class with a particular assignment of types to typevars. +#[salsa::interned(debug)] +pub struct GenericAlias<'db> { + pub(crate) origin: GenericClass<'db>, + pub(crate) specialization: Specialization<'db>, +} + +impl<'db> GenericAlias<'db> { pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { - let scope = self.body_scope(db); - let index = semantic_index(db, scope.file(db)); - index.expect_single_definition(scope.node(db).expect_class()) + self.origin(db).class(db).definition(db) + } +} + +impl<'db> From> for Type<'db> { + fn from(alias: GenericAlias<'db>) -> Type<'db> { + Type::GenericAlias(alias) + } +} + +/// Represents a class type, which might be a non-generic class, or a specialization of a generic +/// class. +#[derive( + Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Supertype, salsa::Update, +)] +pub enum ClassType<'db> { + NonGeneric(NonGenericClass<'db>), + Generic(GenericAlias<'db>), +} + +#[salsa::tracked] +impl<'db> ClassType<'db> { + fn class(self, db: &'db dyn Db) -> &'db Class<'db> { + match self { + Self::NonGeneric(non_generic) => non_generic.class(db), + Self::Generic(generic) => generic.origin(db).class(db), + } + } + + /// Returns the class literal and specialization for this class. For a non-generic class, this + /// is the class itself. For a generic alias, this is the alias's origin. + pub(crate) fn class_literal( + self, + db: &'db dyn Db, + ) -> (ClassLiteralType<'db>, Option>) { + match self { + Self::NonGeneric(non_generic) => (ClassLiteralType::NonGeneric(non_generic), None), + Self::Generic(generic) => ( + ClassLiteralType::Generic(generic.origin(db)), + Some(generic.specialization(db)), + ), + } + } + + pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name { + &self.class(db).name + } + + pub(crate) fn known(self, db: &'db dyn Db) -> Option { + self.class(db).known + } + + pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { + self.class(db).definition(db) + } + + fn specialize_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + match self { + Self::NonGeneric(_) => ty, + Self::Generic(generic) => ty.apply_specialization(db, generic.specialization(db)), + } } /// Return `true` if this class represents `known_class` pub(crate) fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool { - self.known(db) == Some(known_class) + self.class(db).known == Some(known_class) } /// Return `true` if this class represents the builtin class `object` @@ -61,6 +225,202 @@ impl<'db> Class<'db> { self.is_known(db, KnownClass::Object) } + /// Iterate over the [method resolution order] ("MRO") of the class. + /// + /// If the MRO could not be accurately resolved, this method falls back to iterating + /// over an MRO that has the class directly inheriting from `Unknown`. Use + /// [`ClassLiteralType::try_mro`] if you need to distinguish between the success and failure + /// cases rather than simply iterating over the inferred resolution order for the class. + /// + /// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order + pub(super) fn iter_mro(self, db: &'db dyn Db) -> impl Iterator> { + let (class_literal, specialization) = self.class_literal(db); + class_literal.iter_mro(db, specialization) + } + + /// Is this class final? + pub(super) fn is_final(self, db: &'db dyn Db) -> bool { + let (class_literal, _) = self.class_literal(db); + class_literal.is_final(db) + } + + /// Return `true` if `other` is present in this class's MRO. + pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool { + // `is_subclass_of` is checking the subtype relation, in which gradual types do not + // participate, so we should not return `True` if we find `Any/Unknown` in the MRO. + self.iter_mro(db).contains(&ClassBase::Class(other)) + } + + /// Return the metaclass of this class, or `type[Unknown]` if the metaclass cannot be inferred. + pub(super) fn metaclass(self, db: &'db dyn Db) -> Type<'db> { + let (class_literal, _) = self.class_literal(db); + self.specialize_type(db, class_literal.metaclass(db)) + } + + /// Return a type representing "the set of all instances of the metaclass of this class". + pub(super) fn metaclass_instance_type(self, db: &'db dyn Db) -> Type<'db> { + self + .metaclass(db) + .to_instance(db) + .expect("`Type::to_instance()` should always return `Some()` when called on the type of a metaclass") + } + + /// Returns the class member of this class named `name`. + /// + /// The member resolves to a member on the class itself or any of its proper superclasses. + /// + /// TODO: Should this be made private...? + pub(super) fn class_member( + self, + db: &'db dyn Db, + name: &str, + policy: MemberLookupPolicy, + ) -> SymbolAndQualifiers<'db> { + let (class_literal, specialization) = self.class_literal(db); + class_literal + .class_member(db, specialization, name, policy) + .map_type(|ty| self.specialize_type(db, ty)) + } + + /// Returns the inferred type of the class member named `name`. Only bound members + /// or those marked as ClassVars are considered. + /// + /// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope + /// directly. Use [`ClassType::class_member`] if you require a method that will + /// traverse through the MRO until it finds the member. + pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { + let (class_literal, _) = self.class_literal(db); + class_literal + .own_class_member(db, name) + .map_type(|ty| self.specialize_type(db, ty)) + } + + /// Returns the `name` attribute of an instance of this class. + /// + /// The attribute could be defined in the class body, but it could also be an implicitly + /// defined attribute that is only present in a method (typically `__init__`). + /// + /// The attribute might also be defined in a superclass of this class. + pub(super) fn instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { + let (class_literal, specialization) = self.class_literal(db); + class_literal + .instance_member(db, specialization, name) + .map_type(|ty| self.specialize_type(db, ty)) + } + + /// A helper function for `instance_member` that looks up the `name` attribute only on + /// this class, not on its superclasses. + fn own_instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { + let (class_literal, _) = self.class_literal(db); + class_literal + .own_instance_member(db, name) + .map_type(|ty| self.specialize_type(db, ty)) + } +} + +impl<'db> From> for ClassType<'db> { + fn from(generic: GenericAlias<'db>) -> ClassType<'db> { + ClassType::Generic(generic) + } +} + +impl<'db> From> for Type<'db> { + fn from(class: ClassType<'db>) -> Type<'db> { + match class { + ClassType::NonGeneric(non_generic) => non_generic.into(), + ClassType::Generic(generic) => generic.into(), + } + } +} + +/// Represents a single class object at runtime, which might be a non-generic class, or a generic +/// class that has not been specialized. +#[derive( + Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Supertype, salsa::Update, +)] +pub enum ClassLiteralType<'db> { + NonGeneric(NonGenericClass<'db>), + Generic(GenericClass<'db>), +} + +#[salsa::tracked] +impl<'db> ClassLiteralType<'db> { + fn class(self, db: &'db dyn Db) -> &'db Class<'db> { + match self { + Self::NonGeneric(non_generic) => non_generic.class(db), + Self::Generic(generic) => generic.class(db), + } + } + + pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name { + &self.class(db).name + } + + pub(crate) fn known(self, db: &'db dyn Db) -> Option { + self.class(db).known + } + + /// Return `true` if this class represents `known_class` + pub(crate) fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool { + self.class(db).known == Some(known_class) + } + + /// Return `true` if this class represents the builtin class `object` + pub(crate) fn is_object(self, db: &'db dyn Db) -> bool { + self.is_known(db, KnownClass::Object) + } + + pub(crate) fn body_scope(self, db: &'db dyn Db) -> ScopeId<'db> { + self.class(db).body_scope + } + + pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { + self.class(db).definition(db) + } + + pub(crate) fn apply_optional_specialization( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> ClassType<'db> { + match (self, specialization) { + (Self::NonGeneric(non_generic), _) => ClassType::NonGeneric(non_generic), + (Self::Generic(generic), None) => { + let specialization = generic.generic_context(db).default_specialization(db); + ClassType::Generic(GenericAlias::new(db, generic, specialization)) + } + (Self::Generic(generic), Some(specialization)) => { + ClassType::Generic(GenericAlias::new(db, generic, specialization)) + } + } + } + + /// Returns the default specialization of this class. For non-generic classes, the class is + /// returned unchanged. For a non-specialized generic class, we return a generic alias that + /// applies the default specialization to the class's typevars. + pub(crate) fn default_specialization(self, db: &'db dyn Db) -> ClassType<'db> { + match self { + Self::NonGeneric(non_generic) => ClassType::NonGeneric(non_generic), + Self::Generic(generic) => { + let specialization = generic.generic_context(db).default_specialization(db); + ClassType::Generic(GenericAlias::new(db, generic, specialization)) + } + } + } + + /// Returns the unknown specialization of this class. For non-generic classes, the class is + /// returned unchanged. For a non-specialized generic class, we return a generic alias that + /// maps each of the class's typevars to `Unknown`. + pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> ClassType<'db> { + match self { + Self::NonGeneric(non_generic) => ClassType::NonGeneric(non_generic), + Self::Generic(generic) => { + let specialization = generic.generic_context(db).unknown_specialization(db); + ClassType::Generic(GenericAlias::new(db, generic, specialization)) + } + } + } + /// Return an iterator over the inferred types of this class's *explicit* bases. /// /// Note that any class (except for `object`) that has no explicit @@ -79,21 +439,21 @@ impl<'db> Class<'db> { } /// Iterate over this class's explicit bases, filtering out any bases that are not class objects. - fn fully_static_explicit_bases(self, db: &'db dyn Db) -> impl Iterator> { + fn fully_static_explicit_bases(self, db: &'db dyn Db) -> impl Iterator> { self.explicit_bases(db) .iter() .copied() - .filter_map(Type::into_class_literal) - .map(|ClassLiteralType { class }| class) + .filter_map(Type::into_class_type) } #[salsa::tracked(return_ref, cycle_fn=explicit_bases_cycle_recover, cycle_initial=explicit_bases_cycle_initial)] fn explicit_bases_query(self, db: &'db dyn Db) -> Box<[Type<'db>]> { - tracing::trace!("Class::explicit_bases_query: {}", self.name(db)); + let class = self.class(db); + tracing::trace!("ClassLiteralType::explicit_bases_query: {}", class.name); - let class_stmt = self.node(db); + let class_stmt = class.node(db); let class_definition = - semantic_index(db, self.file(db)).expect_single_definition(class_stmt); + semantic_index(db, class.file(db)).expect_single_definition(class_stmt); class_stmt .bases() @@ -102,40 +462,19 @@ impl<'db> Class<'db> { .collect() } - fn file(self, db: &dyn Db) -> File { - self.body_scope(db).file(db) - } - - /// Return the original [`ast::StmtClassDef`] node associated with this class - /// - /// ## Note - /// Only call this function from queries in the same file or your - /// query depends on the AST of another file (bad!). - fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef { - self.body_scope(db).node(db).expect_class() - } - - /// Returns the file range of the class's name. - pub fn focus_range(self, db: &dyn Db) -> FileRange { - FileRange::new(self.file(db), self.node(db).name.range) - } - - pub fn full_range(self, db: &dyn Db) -> FileRange { - FileRange::new(self.file(db), self.node(db).range) - } - /// Return the types of the decorators on this class #[salsa::tracked(return_ref)] fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> { - tracing::trace!("Class::decorators: {}", self.name(db)); + let class = self.class(db); + tracing::trace!("ClassLiteralType::decorators: {}", class.name); - let class_stmt = self.node(db); + let class_stmt = class.node(db); if class_stmt.decorator_list.is_empty() { return Box::new([]); } let class_definition = - semantic_index(db, self.file(db)).expect_single_definition(class_stmt); + semantic_index(db, class.file(db)).expect_single_definition(class_stmt); class_stmt .decorator_list @@ -164,28 +503,43 @@ impl<'db> Class<'db> { /// /// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order #[salsa::tracked(return_ref, cycle_fn=try_mro_cycle_recover, cycle_initial=try_mro_cycle_initial)] - pub(super) fn try_mro(self, db: &'db dyn Db) -> Result, MroError<'db>> { - tracing::trace!("Class::try_mro: {}", self.name(db)); - Mro::of_class(db, self) + pub(super) fn try_mro( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> Result, MroError<'db>> { + let class = self.class(db); + tracing::trace!("ClassLiteralType::try_mro: {}", class.name); + Mro::of_class(db, self, specialization) } /// Iterate over the [method resolution order] ("MRO") of the class. /// /// If the MRO could not be accurately resolved, this method falls back to iterating /// over an MRO that has the class directly inheriting from `Unknown`. Use - /// [`Class::try_mro`] if you need to distinguish between the success and failure + /// [`ClassLiteralType::try_mro`] if you need to distinguish between the success and failure /// cases rather than simply iterating over the inferred resolution order for the class. /// /// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order - pub(super) fn iter_mro(self, db: &'db dyn Db) -> impl Iterator> { - MroIterator::new(db, self) + pub(super) fn iter_mro( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> impl Iterator> { + MroIterator::new(db, self, specialization) } /// Return `true` if `other` is present in this class's MRO. - pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: Class) -> bool { + pub(super) fn is_subclass_of( + self, + db: &'db dyn Db, + specialization: Option>, + other: ClassType<'db>, + ) -> bool { // `is_subclass_of` is checking the subtype relation, in which gradual types do not // participate, so we should not return `True` if we find `Any/Unknown` in the MRO. - self.iter_mro(db).contains(&ClassBase::Class(other)) + self.iter_mro(db, specialization) + .contains(&ClassBase::Class(other)) } /// Return the explicit `metaclass` of this class, if one is defined. @@ -194,14 +548,15 @@ impl<'db> Class<'db> { /// Only call this function from queries in the same file or your /// query depends on the AST of another file (bad!). fn explicit_metaclass(self, db: &'db dyn Db) -> Option> { - let class_stmt = self.node(db); + let class = self.class(db); + let class_stmt = class.node(db); let metaclass_node = &class_stmt .arguments .as_ref()? .find_keyword("metaclass")? .value; - let class_definition = self.definition(db); + let class_definition = class.definition(db); Some(definition_expression_type( db, @@ -227,7 +582,8 @@ impl<'db> Class<'db> { /// Return the metaclass of this class, or an error if the metaclass cannot be inferred. #[salsa::tracked] pub(super) fn try_metaclass(self, db: &'db dyn Db) -> Result, MetaclassError<'db>> { - tracing::trace!("Class::try_metaclass: {}", self.name(db)); + let class = self.class(db); + tracing::trace!("ClassLiteralType::try_metaclass: {}", class.name); // Identify the class's own metaclass (or take the first base class's metaclass). let mut base_classes = self.fully_static_explicit_bases(db).peekable(); @@ -243,18 +599,19 @@ impl<'db> Class<'db> { let (metaclass, class_metaclass_was_from) = if let Some(metaclass) = explicit_metaclass { (metaclass, self) } else if let Some(base_class) = base_classes.next() { - (base_class.metaclass(db), base_class) + let (base_class_literal, _) = base_class.class_literal(db); + (base_class.metaclass(db), base_class_literal) } else { (KnownClass::Type.to_class_literal(db), self) }; - let mut candidate = if let Type::ClassLiteral(metaclass_ty) = metaclass { + let mut candidate = if let Some(metaclass_ty) = metaclass.into_class_type() { MetaclassCandidate { - metaclass: metaclass_ty.class, + metaclass: metaclass_ty, explicit_metaclass_of: class_metaclass_was_from, } } else { - let name = Type::string_literal(db, self.name(db)); + let name = Type::string_literal(db, &class.name); let bases = TupleType::from_elements(db, self.explicit_bases(db)); // TODO: Should be `dict[str, Any]` let namespace = KnownClass::Dict.to_instance(db); @@ -290,32 +647,34 @@ impl<'db> Class<'db> { // - https://github.com/python/cpython/blob/83ba8c2bba834c0b92de669cac16fcda17485e0e/Objects/typeobject.c#L3629-L3663 for base_class in base_classes { let metaclass = base_class.metaclass(db); - let Type::ClassLiteral(metaclass) = metaclass else { + let Some(metaclass) = metaclass.into_class_type() else { continue; }; - if metaclass.class.is_subclass_of(db, candidate.metaclass) { + if metaclass.is_subclass_of(db, candidate.metaclass) { + let (base_class_literal, _) = base_class.class_literal(db); candidate = MetaclassCandidate { - metaclass: metaclass.class, - explicit_metaclass_of: base_class, + metaclass, + explicit_metaclass_of: base_class_literal, }; continue; } - if candidate.metaclass.is_subclass_of(db, metaclass.class) { + if candidate.metaclass.is_subclass_of(db, metaclass) { continue; } + let (base_class_literal, _) = base_class.class_literal(db); return Err(MetaclassError { kind: MetaclassErrorKind::Conflict { candidate1: candidate, candidate2: MetaclassCandidate { - metaclass: metaclass.class, - explicit_metaclass_of: base_class, + metaclass, + explicit_metaclass_of: base_class_literal, }, candidate1_is_base_class: explicit_metaclass.is_none(), }, }); } - Ok(Type::class_literal(candidate.metaclass)) + Ok(candidate.metaclass.into()) } /// Returns the class member of this class named `name`. @@ -326,11 +685,12 @@ impl<'db> Class<'db> { pub(super) fn class_member( self, db: &'db dyn Db, + specialization: Option>, name: &str, policy: MemberLookupPolicy, ) -> SymbolAndQualifiers<'db> { if name == "__mro__" { - let tuple_elements = self.iter_mro(db).map(Type::from); + let tuple_elements = self.iter_mro(db, specialization).map(Type::from); return Symbol::bound(TupleType::from_elements(db, tuple_elements)).into(); } @@ -345,7 +705,7 @@ impl<'db> Class<'db> { let mut lookup_result: LookupResult<'db> = Err(LookupError::Unbound(TypeQualifiers::empty())); - for superclass in self.iter_mro(db) { + for superclass in self.iter_mro(db, specialization) { match superclass { ClassBase::Dynamic(DynamicType::TodoProtocol) => { // TODO: We currently skip `Protocol` when looking up class members, in order to @@ -415,7 +775,7 @@ impl<'db> Class<'db> { /// or those marked as ClassVars are considered. /// /// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope - /// directly. Use [`Class::class_member`] if you require a method that will + /// directly. Use [`ClassLiteralType::class_member`] if you require a method that will /// traverse through the MRO until it finds the member. pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { let body_scope = self.body_scope(db); @@ -428,11 +788,16 @@ impl<'db> Class<'db> { /// defined attribute that is only present in a method (typically `__init__`). /// /// The attribute might also be defined in a superclass of this class. - pub(super) fn instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { + pub(super) fn instance_member( + self, + db: &'db dyn Db, + specialization: Option>, + name: &str, + ) -> SymbolAndQualifiers<'db> { let mut union = UnionBuilder::new(db); let mut union_qualifiers = TypeQualifiers::empty(); - for superclass in self.iter_mro(db) { + for superclass in self.iter_mro(db, specialization) { match superclass { ClassBase::Dynamic(DynamicType::TodoProtocol) => { // TODO: We currently skip `Protocol` when looking up instance members, in order to @@ -680,21 +1045,22 @@ impl<'db> Class<'db> { /// Also, populates `visited_classes` with all base classes of `self`. fn is_cyclically_defined_recursive<'db>( db: &'db dyn Db, - class: Class<'db>, - classes_on_stack: &mut IndexSet>, - visited_classes: &mut IndexSet>, + class: ClassLiteralType<'db>, + classes_on_stack: &mut IndexSet>, + visited_classes: &mut IndexSet>, ) -> bool { let mut result = false; for explicit_base_class in class.fully_static_explicit_bases(db) { - if !classes_on_stack.insert(explicit_base_class) { + let (explicit_base_class_literal, _) = explicit_base_class.class_literal(db); + if !classes_on_stack.insert(explicit_base_class_literal) { return true; } - if visited_classes.insert(explicit_base_class) { + if visited_classes.insert(explicit_base_class_literal) { // If we find a cycle, keep searching to check if we can reach the starting class. result |= is_cyclically_defined_recursive( db, - explicit_base_class, + explicit_base_class_literal, classes_on_stack, visited_classes, ); @@ -718,48 +1084,13 @@ impl<'db> Class<'db> { } } -fn explicit_bases_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &[Type<'db>], - _count: u32, - _self: Class<'db>, -) -> salsa::CycleRecoveryAction]>> { - salsa::CycleRecoveryAction::Iterate -} - -fn explicit_bases_cycle_initial<'db>(_db: &'db dyn Db, _self: Class<'db>) -> Box<[Type<'db>]> { - Box::default() -} - -fn try_mro_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Result, MroError<'db>>, - _count: u32, - _self: Class<'db>, -) -> salsa::CycleRecoveryAction, MroError<'db>>> { - salsa::CycleRecoveryAction::Iterate -} - -#[allow(clippy::unnecessary_wraps)] -fn try_mro_cycle_initial<'db>( - db: &'db dyn Db, - self_: Class<'db>, -) -> Result, MroError<'db>> { - Ok(Mro::from_error(db, self_)) -} - -#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)] -fn inheritance_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Option, - _count: u32, - _self: Class<'db>, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate -} - -fn inheritance_cycle_initial<'db>(_db: &'db dyn Db, _self: Class<'db>) -> Option { - None +impl<'db> From> for Type<'db> { + fn from(class: ClassLiteralType<'db>) -> Type<'db> { + match class { + ClassLiteralType::NonGeneric(non_generic) => non_generic.into(), + ClassLiteralType::Generic(generic) => generic.into(), + } + } } #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] @@ -778,48 +1109,13 @@ impl InheritanceCycle { } } -/// A singleton type representing a single class object at runtime. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)] -pub struct ClassLiteralType<'db> { - pub(super) class: Class<'db>, -} - -impl<'db> ClassLiteralType<'db> { - pub(super) fn class(self) -> Class<'db> { - self.class - } - - pub(crate) fn body_scope(self, db: &'db dyn Db) -> ScopeId<'db> { - self.class.body_scope(db) - } - - pub(super) fn class_member( - self, - db: &'db dyn Db, - name: &str, - policy: MemberLookupPolicy, - ) -> SymbolAndQualifiers<'db> { - self.class.class_member(db, name, policy) - } -} - -impl<'db> From> for Type<'db> { - fn from(value: ClassLiteralType<'db>) -> Self { - Self::ClassLiteral(value) - } -} - /// A type representing the set of runtime objects which are instances of a certain class. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update)] pub struct InstanceType<'db> { - pub(super) class: Class<'db>, + pub class: ClassType<'db>, } impl<'db> InstanceType<'db> { - pub(super) fn class(self) -> Class<'db> { - self.class - } - pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: InstanceType<'db>) -> bool { // N.B. The subclass relation is fully static self.class.is_subclass_of(db, other.class) @@ -839,7 +1135,7 @@ impl<'db> From> for Type<'db> { /// Note: good candidates are any classes in `[crate::module_resolver::module::KnownModule]` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(test, derive(strum_macros::EnumIter))] -pub enum KnownClass { +pub(crate) enum KnownClass { // To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored` // and grep for the symbol name in any `.pyi` file. @@ -1081,8 +1377,8 @@ impl<'db> KnownClass { /// If the class cannot be found in typeshed, a debug-level log message will be emitted stating this. pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> { self.to_class_literal(db) - .into_class_literal() - .map(|ClassLiteralType { class }| Type::instance(class)) + .into_class_type() + .map(Type::instance) .unwrap_or_else(Type::unknown) } @@ -1096,9 +1392,9 @@ impl<'db> KnownClass { ) -> Result, KnownClassLookupError<'db>> { let symbol = known_module_symbol(db, self.canonical_module(db), self.name(db)).symbol; match symbol { - Symbol::Type(Type::ClassLiteral(class_type), Boundness::Bound) => Ok(class_type), - Symbol::Type(Type::ClassLiteral(class_type), Boundness::PossiblyUnbound) => { - Err(KnownClassLookupError::ClassPossiblyUnbound { class_type }) + Symbol::Type(Type::ClassLiteral(class_literal), Boundness::Bound) => Ok(class_literal), + Symbol::Type(Type::ClassLiteral(class_literal), Boundness::PossiblyUnbound) => { + Err(KnownClassLookupError::ClassPossiblyUnbound { class_literal }) } Symbol::Type(found_type, _) => { Err(KnownClassLookupError::SymbolNotAClass { found_type }) @@ -1133,8 +1429,8 @@ impl<'db> KnownClass { } match lookup_error { - KnownClassLookupError::ClassPossiblyUnbound { class_type, .. } => { - Type::class_literal(class_type.class) + KnownClassLookupError::ClassPossiblyUnbound { class_literal, .. } => { + class_literal.into() } KnownClassLookupError::ClassNotFound { .. } | KnownClassLookupError::SymbolNotAClass { .. } => Type::unknown(), @@ -1148,16 +1444,16 @@ impl<'db> KnownClass { /// If the class cannot be found in typeshed, a debug-level log message will be emitted stating this. pub(crate) fn to_subclass_of(self, db: &'db dyn Db) -> Type<'db> { self.to_class_literal(db) - .into_class_literal() - .map(|ClassLiteralType { class }| SubclassOfType::from(db, class)) + .into_class_type() + .map(|class| SubclassOfType::from(db, class)) .unwrap_or_else(SubclassOfType::subclass_of_unknown) } /// Return `true` if this symbol can be resolved to a class definition `class` in typeshed, /// *and* `class` is a subclass of `other`. - pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: Class<'db>) -> bool { + pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool { self.try_to_class_literal(db) - .is_ok_and(|ClassLiteralType { class }| class.is_subclass_of(db, other)) + .is_ok_and(|class| class.is_subclass_of(db, None, other)) } /// Return the module in which we should look up the definition for this class @@ -1489,7 +1785,9 @@ pub(crate) enum KnownClassLookupError<'db> { SymbolNotAClass { found_type: Type<'db> }, /// There is a symbol by that name in the expected typeshed module, /// and it's a class definition, but it's possibly unbound. - ClassPossiblyUnbound { class_type: ClassLiteralType<'db> }, + ClassPossiblyUnbound { + class_literal: ClassLiteralType<'db>, + }, } impl<'db> KnownClassLookupError<'db> { @@ -1769,7 +2067,7 @@ impl<'db> KnownInstanceType<'db> { } /// Return `true` if this symbol is an instance of `class`. - pub(super) fn is_instance_of(self, db: &'db dyn Db, class: Class<'db>) -> bool { + pub(super) fn is_instance_of(self, db: &'db dyn Db, class: ClassType<'db>) -> bool { self.class().is_subclass_of(db, class) } diff --git a/crates/red_knot_python_semantic/src/types/class_base.rs b/crates/red_knot_python_semantic/src/types/class_base.rs index d55f10192f..26e563316a 100644 --- a/crates/red_knot_python_semantic/src/types/class_base.rs +++ b/crates/red_knot_python_semantic/src/types/class_base.rs @@ -1,16 +1,19 @@ -use crate::types::{todo_type, Class, DynamicType, KnownClass, KnownInstanceType, Type}; +use crate::types::{todo_type, ClassType, DynamicType, KnownClass, KnownInstanceType, Type}; use crate::Db; use itertools::Either; /// Enumeration of the possible kinds of types we allow in class bases. /// -/// This is much more limited than the [`Type`] enum: -/// all types that would be invalid to have as a class base are -/// transformed into [`ClassBase::unknown`] +/// This is much more limited than the [`Type`] enum: all types that would be invalid to have as a +/// class base are transformed into [`ClassBase::unknown`] +/// +/// Note that a non-specialized generic class _cannot_ be a class base. When we see a +/// non-specialized generic class in any type expression (including the list of base classes), we +/// automatically construct the default specialization for that class. #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update)] pub(crate) enum ClassBase<'db> { Dynamic(DynamicType), - Class(Class<'db>), + Class(ClassType<'db>), } impl<'db> ClassBase<'db> { @@ -39,7 +42,12 @@ impl<'db> ClassBase<'db> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.base { ClassBase::Dynamic(dynamic) => dynamic.fmt(f), - ClassBase::Class(class) => write!(f, "", class.name(self.db)), + ClassBase::Class(class @ ClassType::NonGeneric(_)) => { + write!(f, "", class.name(self.db)) + } + ClassBase::Class(ClassType::Generic(alias)) => { + write!(f, "", alias.display(self.db)) + } } } } @@ -51,8 +59,8 @@ impl<'db> ClassBase<'db> { pub(super) fn object(db: &'db dyn Db) -> Self { KnownClass::Object .to_class_literal(db) - .into_class_literal() - .map_or(Self::unknown(), |literal| Self::Class(literal.class())) + .into_class_type() + .map_or(Self::unknown(), Self::Class) } /// Attempt to resolve `ty` into a `ClassBase`. @@ -61,11 +69,12 @@ impl<'db> ClassBase<'db> { pub(super) fn try_from_type(db: &'db dyn Db, ty: Type<'db>) -> Option { match ty { Type::Dynamic(dynamic) => Some(Self::Dynamic(dynamic)), - Type::ClassLiteral(literal) => Some(if literal.class().is_known(db, KnownClass::Any) { + Type::ClassLiteral(literal) => Some(if literal.is_known(db, KnownClass::Any) { Self::Dynamic(DynamicType::Any) } else { - Self::Class(literal.class()) + Self::Class(literal.default_specialization(db)) }), + Type::GenericAlias(generic) => Some(Self::Class(ClassType::Generic(generic))), Type::Union(_) => None, // TODO -- forces consideration of multiple possible MROs? Type::Intersection(_) => None, // TODO -- probably incorrect? Type::Instance(_) => None, // TODO -- handle `__mro_entries__`? @@ -159,7 +168,7 @@ impl<'db> ClassBase<'db> { } } - pub(super) fn into_class(self) -> Option> { + pub(super) fn into_class(self) -> Option> { match self { Self::Class(class) => Some(class), Self::Dynamic(_) => None, @@ -178,8 +187,8 @@ impl<'db> ClassBase<'db> { } } -impl<'db> From> for ClassBase<'db> { - fn from(value: Class<'db>) -> Self { +impl<'db> From> for ClassBase<'db> { + fn from(value: ClassType<'db>) -> Self { ClassBase::Class(value) } } @@ -188,7 +197,7 @@ impl<'db> From> for Type<'db> { fn from(value: ClassBase<'db>) -> Self { match value { ClassBase::Dynamic(dynamic) => Type::Dynamic(dynamic), - ClassBase::Class(class) => Type::class_literal(class), + ClassBase::Class(class) => class.into(), } } } diff --git a/crates/red_knot_python_semantic/src/types/diagnostic.rs b/crates/red_knot_python_semantic/src/types/diagnostic.rs index 4dbff6cc56..154aa3dd71 100644 --- a/crates/red_knot_python_semantic/src/types/diagnostic.rs +++ b/crates/red_knot_python_semantic/src/types/diagnostic.rs @@ -7,7 +7,7 @@ use crate::types::string_annotation::{ IMPLICIT_CONCATENATED_STRING_TYPE_ANNOTATION, INVALID_SYNTAX_IN_FORWARD_ANNOTATION, RAW_STRING_TYPE_ANNOTATION, }; -use crate::types::{ClassLiteralType, KnownInstanceType, Type}; +use crate::types::{KnownInstanceType, Type}; use ruff_db::diagnostic::{Diagnostic, OldSecondaryDiagnosticMessage, Span}; use ruff_python_ast::{self as ast, AnyNodeRef}; use ruff_text_size::Ranged; @@ -1025,7 +1025,7 @@ fn report_invalid_assignment_with_message( message: std::fmt::Arguments, ) { match target_ty { - Type::ClassLiteral(ClassLiteralType { class }) => { + Type::ClassLiteral(class) => { context.report_lint(&INVALID_ASSIGNMENT, node, format_args!( "Implicit shadowing of class `{}`; annotate to make it explicit if this is intentional", class.name(context.db()))); diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 7c5fac3c98..87825bc5a6 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -6,11 +6,13 @@ use ruff_db::display::FormatterJoinExtension; use ruff_python_ast::str::{Quote, TripleQuotes}; use ruff_python_literal::escape::AsciiEscape; +use crate::types::class::{ClassType, GenericAlias, GenericClass}; use crate::types::class_base::ClassBase; +use crate::types::generics::Specialization; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ - ClassLiteralType, InstanceType, IntersectionType, KnownClass, MethodWrapperKind, - StringLiteralType, Type, UnionType, WrapperDescriptorKind, + InstanceType, IntersectionType, KnownClass, MethodWrapperKind, StringLiteralType, Type, + TypeVarInstance, UnionType, WrapperDescriptorKind, }; use crate::Db; use rustc_hash::FxHashMap; @@ -34,7 +36,7 @@ impl Display for DisplayType<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let representation = self.ty.representation(self.db); match self.ty { - Type::ClassLiteral(literal) if literal.class().is_known(self.db, KnownClass::Any) => { + Type::ClassLiteral(literal) if literal.is_known(self.db, KnownClass::Any) => { write!(f, "typing.Any") } Type::IntLiteral(_) @@ -42,6 +44,7 @@ impl Display for DisplayType<'_> { | Type::StringLiteral(_) | Type::BytesLiteral(_) | Type::ClassLiteral(_) + | Type::GenericAlias(_) | Type::FunctionLiteral(_) => { write!(f, "Literal[{representation}]") } @@ -69,20 +72,21 @@ impl Display for DisplayRepresentation<'_> { match self.ty { Type::Dynamic(dynamic) => dynamic.fmt(f), Type::Never => f.write_str("Never"), - Type::Instance(InstanceType { class }) => { - let representation = match class.known(self.db) { - Some(KnownClass::NoneType) => "None", - Some(KnownClass::NoDefaultType) => "NoDefault", - _ => class.name(self.db), - }; - f.write_str(representation) - } + Type::Instance(InstanceType { class }) => match (class, class.known(self.db)) { + (_, Some(KnownClass::NoneType)) => f.write_str("None"), + (_, Some(KnownClass::NoDefaultType)) => f.write_str("NoDefault"), + (ClassType::NonGeneric(class), _) => f.write_str(&class.class(self.db).name), + (ClassType::Generic(alias), _) => write!(f, "{}", alias.display(self.db)), + }, Type::PropertyInstance(_) => f.write_str("property"), Type::ModuleLiteral(module) => { write!(f, "", module.module(self.db).name()) } // TODO functions and classes should display using a fully qualified name - Type::ClassLiteral(ClassLiteralType { class }) => f.write_str(class.name(self.db)), + Type::ClassLiteral(class) => f.write_str(class.name(self.db)), + Type::GenericAlias(generic) => { + write!(f, "{}", generic.display(self.db)) + } Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() { // Only show the bare class name here; ClassBase::display would render this as // type[] instead of type[Foo]. @@ -90,21 +94,49 @@ impl Display for DisplayRepresentation<'_> { ClassBase::Dynamic(dynamic) => write!(f, "type[{dynamic}]"), }, Type::KnownInstance(known_instance) => f.write_str(known_instance.repr(self.db)), - Type::FunctionLiteral(function) => f.write_str(function.name(self.db)), + Type::FunctionLiteral(function) => { + f.write_str(function.name(self.db))?; + if let Some(specialization) = function.specialization(self.db) { + specialization.display_short(self.db).fmt(f)?; + } + Ok(()) + } Type::Callable(callable) => callable.signature(self.db).display(self.db).fmt(f), Type::BoundMethod(bound_method) => { + let function = bound_method.function(self.db); + let self_instance = bound_method.self_instance(self.db); + let self_instance_specialization = match self_instance { + Type::Instance(InstanceType { + class: ClassType::Generic(alias), + }) => Some(alias.specialization(self.db)), + _ => None, + }; + let specialization = match function.specialization(self.db) { + Some(specialization) + if self_instance_specialization.is_none_or(|sis| specialization == sis) => + { + specialization.display_short(self.db).to_string() + } + _ => String::new(), + }; write!( f, - "", - method = bound_method.function(self.db).name(self.db), - instance = bound_method.self_instance(self.db).display(self.db) + "", + method = function.name(self.db), + instance = bound_method.self_instance(self.db).display(self.db), ) } Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { write!( f, - "", - function = function.name(self.db) + "", + function = function.name(self.db), + specialization = if let Some(specialization) = function.specialization(self.db) + { + specialization.display_short(self.db).to_string() + } else { + String::new() + }, ) } Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(_)) => { @@ -174,6 +206,86 @@ impl Display for DisplayRepresentation<'_> { } } +impl<'db> GenericAlias<'db> { + pub(crate) fn display(&'db self, db: &'db dyn Db) -> DisplayGenericAlias<'db> { + DisplayGenericAlias { + origin: self.origin(db), + specialization: self.specialization(db), + db, + } + } +} + +pub(crate) struct DisplayGenericAlias<'db> { + origin: GenericClass<'db>, + specialization: Specialization<'db>, + db: &'db dyn Db, +} + +impl Display for DisplayGenericAlias<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "{origin}{specialization}", + origin = self.origin.class(self.db).name, + specialization = self.specialization.display_short(self.db), + ) + } +} + +impl<'db> Specialization<'db> { + /// Renders the specialization in full, e.g. `{T = int, U = str}`. + pub fn display(&'db self, db: &'db dyn Db) -> DisplaySpecialization<'db> { + DisplaySpecialization { + typevars: self.generic_context(db).variables(db), + types: self.types(db), + db, + full: true, + } + } + + /// Renders the specialization as it would appear in a subscript expression, e.g. `[int, str]`. + pub fn display_short(&'db self, db: &'db dyn Db) -> DisplaySpecialization<'db> { + DisplaySpecialization { + typevars: self.generic_context(db).variables(db), + types: self.types(db), + db, + full: false, + } + } +} + +pub struct DisplaySpecialization<'db> { + typevars: &'db [TypeVarInstance<'db>], + types: &'db [Type<'db>], + db: &'db dyn Db, + full: bool, +} + +impl Display for DisplaySpecialization<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.full { + f.write_char('{')?; + for (idx, (var, ty)) in self.typevars.iter().zip(self.types).enumerate() { + if idx > 0 { + f.write_str(", ")?; + } + write!(f, "{} = {}", var.name(self.db), ty.display(self.db))?; + } + f.write_char('}') + } else { + f.write_char('[')?; + for (idx, (_, ty)) in self.typevars.iter().zip(self.types).enumerate() { + if idx > 0 { + f.write_str(", ")?; + } + write!(f, "{}", ty.display(self.db))?; + } + f.write_char(']') + } + } +} + impl<'db> Signature<'db> { fn display(&'db self, db: &'db dyn Db) -> DisplaySignature<'db> { DisplaySignature { diff --git a/crates/red_knot_python_semantic/src/types/generics.rs b/crates/red_knot_python_semantic/src/types/generics.rs new file mode 100644 index 0000000000..977808c84b --- /dev/null +++ b/crates/red_knot_python_semantic/src/types/generics.rs @@ -0,0 +1,148 @@ +use ruff_python_ast as ast; + +use crate::semantic_index::SemanticIndex; +use crate::types::signatures::{Parameter, Parameters, Signature}; +use crate::types::{ + declaration_type, KnownInstanceType, Type, TypeVarBoundOrConstraints, TypeVarInstance, + UnionType, +}; +use crate::Db; + +/// A list of formal type variables for a generic function, class, or type alias. +#[salsa::tracked(debug)] +pub struct GenericContext<'db> { + #[return_ref] + pub(crate) variables: Box<[TypeVarInstance<'db>]>, +} + +impl<'db> GenericContext<'db> { + pub(crate) fn from_type_params( + db: &'db dyn Db, + index: &'db SemanticIndex<'db>, + type_params_node: &ast::TypeParams, + ) -> Self { + let variables = type_params_node + .iter() + .filter_map(|type_param| Self::variable_from_type_param(db, index, type_param)) + .collect(); + Self::new(db, variables) + } + + fn variable_from_type_param( + db: &'db dyn Db, + index: &'db SemanticIndex<'db>, + type_param_node: &ast::TypeParam, + ) -> Option> { + match type_param_node { + ast::TypeParam::TypeVar(node) => { + let definition = index.expect_single_definition(node); + let Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) = + declaration_type(db, definition).inner_type() + else { + panic!("typevar should be inferred as a TypeVarInstance"); + }; + Some(typevar) + } + // TODO: Support these! + ast::TypeParam::ParamSpec(_) => None, + ast::TypeParam::TypeVarTuple(_) => None, + } + } + + pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { + let parameters = Parameters::new( + self.variables(db) + .iter() + .map(|typevar| Self::parameter_from_typevar(db, *typevar)), + ); + Signature::new(parameters, None) + } + + fn parameter_from_typevar(db: &'db dyn Db, typevar: TypeVarInstance<'db>) -> Parameter<'db> { + let mut parameter = Parameter::positional_only(Some(typevar.name(db).clone())); + match typevar.bound_or_constraints(db) { + Some(TypeVarBoundOrConstraints::UpperBound(bound)) => { + // TODO: This should be a type form. + parameter = parameter.with_annotated_type(bound); + } + Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { + // TODO: This should be a new type variant where only these exact types are + // assignable, and not subclasses of them, nor a union of them. + parameter = parameter + .with_annotated_type(UnionType::from_elements(db, constraints.iter(db))); + } + None => {} + } + parameter + } + + pub(crate) fn default_specialization(self, db: &'db dyn Db) -> Specialization<'db> { + let types = self + .variables(db) + .iter() + .map(|typevar| typevar.default_ty(db).unwrap_or(Type::unknown())) + .collect(); + self.specialize(db, types) + } + + pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> { + let types = vec![Type::unknown(); self.variables(db).len()]; + self.specialize(db, types.into()) + } + + pub(crate) fn specialize( + self, + db: &'db dyn Db, + types: Box<[Type<'db>]>, + ) -> Specialization<'db> { + Specialization::new(db, self, types) + } +} + +/// An assignment of a specific type to each type variable in a generic scope. +#[salsa::tracked(debug)] +pub struct Specialization<'db> { + pub(crate) generic_context: GenericContext<'db>, + #[return_ref] + pub(crate) types: Box<[Type<'db>]>, +} + +impl<'db> Specialization<'db> { + /// Applies a specialization to this specialization. This is used, for instance, when a generic + /// class inherits from a generic alias: + /// + /// ```py + /// class A[T]: ... + /// class B[U](A[U]): ... + /// ``` + /// + /// `B` is a generic class, whose MRO includes the generic alias `A[U]`, which specializes `A` + /// with the specialization `{T: U}`. If `B` is specialized to `B[int]`, with specialization + /// `{U: int}`, we can apply the second specialization to the first, resulting in `T: int`. + /// That lets us produce the generic alias `A[int]`, which is the corresponding entry in the + /// MRO of `B[int]`. + pub(crate) fn apply_specialization(self, db: &'db dyn Db, other: Specialization<'db>) -> Self { + let types = self + .types(db) + .into_iter() + .map(|ty| ty.apply_specialization(db, other)) + .collect(); + Specialization::new(db, self.generic_context(db), types) + } + + pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { + let types = self.types(db).iter().map(|ty| ty.normalized(db)).collect(); + Self::new(db, self.generic_context(db), types) + } + + /// Returns the type that a typevar is specialized to, or None if the typevar isn't part of + /// this specialization. + pub(crate) fn get(self, db: &'db dyn Db, typevar: TypeVarInstance<'db>) -> Option> { + self.generic_context(db) + .variables(db) + .into_iter() + .zip(self.types(db)) + .find(|(var, _)| **var == typevar) + .map(|(_, ty)| *ty) + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index bf38e1ae70..bee37ffefa 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -37,6 +37,7 @@ use itertools::{Either, Itertools}; use ruff_db::diagnostic::{DiagnosticId, Severity}; use ruff_db::files::File; use ruff_db::parsed::parsed_module; +use ruff_python_ast::visitor::{walk_expr, Visitor}; use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext}; use ruff_text_size::{Ranged, TextRange}; use rustc_hash::{FxHashMap, FxHashSet}; @@ -61,6 +62,7 @@ use crate::symbol::{ typing_extensions_symbol, Boundness, LookupError, }; use crate::types::call::{Argument, Bindings, CallArgumentTypes, CallArguments, CallError}; +use crate::types::class::MetaclassErrorKind; use crate::types::diagnostic::{ report_implicit_return_type, report_invalid_arguments_to_annotated, report_invalid_arguments_to_callable, report_invalid_assignment, @@ -73,18 +75,18 @@ use crate::types::diagnostic::{ INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_IMPORT, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR, }; +use crate::types::generics::GenericContext; use crate::types::mro::MroErrorKind; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - class::MetaclassErrorKind, todo_type, Class, DynamicType, FunctionType, IntersectionBuilder, - IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MetaclassCandidate, Parameter, - ParameterForm, Parameters, SliceLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, + todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType, + DynamicType, FunctionDecorators, FunctionType, GenericAlias, GenericClass, IntersectionBuilder, + IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MemberLookupPolicy, + MetaclassCandidate, NonGenericClass, Parameter, ParameterForm, Parameters, Signature, + Signatures, SliceLiteralType, StringLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType, TypeAndQualifiers, TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, UnionType, }; -use crate::types::{ - CallableType, FunctionDecorators, MemberLookupPolicy, Signature, StringLiteralType, -}; use crate::unpack::{Unpack, UnpackPosition}; use crate::util::subscript::{PyIndex, PySlice}; use crate::Db; @@ -102,7 +104,6 @@ use super::slots::check_class_slots; use super::string_annotation::{ parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, }; -use super::{CallDunderError, ClassLiteralType}; /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the @@ -708,7 +709,7 @@ impl<'db> TypeInferenceBuilder<'db> { if let DefinitionKind::Class(class) = definition.kind(self.db()) { ty.inner_type() .into_class_literal() - .map(|ty| (ty.class(), class.node())) + .map(|ty| (ty, class.node())) } else { None } @@ -736,10 +737,7 @@ impl<'db> TypeInferenceBuilder<'db> { // (2) Check for classes that inherit from `@final` classes for (i, base_class) in class.explicit_bases(self.db()).iter().enumerate() { // dynamic/unknown bases are never `@final` - let Some(base_class) = base_class - .into_class_literal() - .map(super::class::ClassLiteralType::class) - else { + let Some(base_class) = base_class.into_class_literal() else { continue; }; if !base_class.is_final(self.db()) { @@ -757,7 +755,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // (3) Check that the class's MRO is resolvable - match class.try_mro(self.db()).as_ref() { + match class.try_mro(self.db(), None).as_ref() { Err(mro_error) => { match mro_error.reason() { MroErrorKind::DuplicateBases(duplicates) => { @@ -983,7 +981,7 @@ impl<'db> TypeInferenceBuilder<'db> { Type::BooleanLiteral(_) | Type::IntLiteral(_) => {} Type::Instance(instance) if matches!( - instance.class().known(self.db()), + instance.class.known(self.db()), Some(KnownClass::Float | KnownClass::Int | KnownClass::Bool) ) => {} _ => return false, @@ -1415,7 +1413,7 @@ impl<'db> TypeInferenceBuilder<'db> { continue; } } else if let Type::ClassLiteral(class) = decorator_ty { - if class.class.is_known(self.db(), KnownClass::Classmethod) { + if class.is_known(self.db(), KnownClass::Classmethod) { function_decorators |= FunctionDecorators::CLASSMETHOD; continue; } @@ -1453,12 +1451,15 @@ impl<'db> TypeInferenceBuilder<'db> { .node_scope(NodeWithScopeRef::Function(function)) .to_scope_id(self.db(), self.file()); + let specialization = None; + let mut inferred_ty = Type::FunctionLiteral(FunctionType::new( self.db(), &name.id, function_kind, body_scope, function_decorators, + specialization, )); for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() { @@ -1695,6 +1696,10 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_decorator(decorator); } + let generic_context = type_params.as_ref().map(|type_params| { + GenericContext::from_type_params(self.db(), self.index, type_params) + }); + let body_scope = self .index .node_scope(NodeWithScopeRef::Class(class_node)) @@ -1702,8 +1707,18 @@ impl<'db> TypeInferenceBuilder<'db> { let maybe_known_class = KnownClass::try_from_file_and_name(self.db(), self.file(), name); - let class = Class::new(self.db(), &name.id, body_scope, maybe_known_class); - let class_ty = Type::class_literal(class); + let class = Class { + name: name.id.clone(), + body_scope, + known: maybe_known_class, + }; + let class_literal = match generic_context { + Some(generic_context) => { + ClassLiteralType::Generic(GenericClass::new(self.db(), class, generic_context)) + } + None => ClassLiteralType::NonGeneric(NonGenericClass::new(self.db(), class)), + }; + let class_ty = Type::from(class_literal); self.add_declaration_with_binding( class_node.into(), @@ -1719,8 +1734,12 @@ impl<'db> TypeInferenceBuilder<'db> { } // Inference of bases deferred in stubs - // TODO also defer stringified generic type parameters - if self.are_all_types_deferred() { + // TODO: Only defer the references that are actually string literals, instead of + // deferring the entire class definition if a string literal occurs anywhere in the + // base class list. + if self.are_all_types_deferred() + || class_node.bases().iter().any(contains_string_literal) + { self.types.deferred.insert(definition); } else { for base in class_node.bases() { @@ -2532,7 +2551,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - Type::ClassLiteral(..) | Type::SubclassOf(..) => { + Type::ClassLiteral(..) | Type::GenericAlias(..) | Type::SubclassOf(..) => { match object_ty.class_member(db, attribute.into()) { SymbolAndQualifiers { symbol: Symbol::Type(meta_attr_ty, meta_attr_boundness), @@ -2856,10 +2875,7 @@ impl<'db> TypeInferenceBuilder<'db> { // Handle various singletons. if let Type::Instance(instance) = declared_ty.inner_type() { - if instance - .class() - .is_known(self.db(), KnownClass::SpecialForm) - { + if instance.class.is_known(self.db(), KnownClass::SpecialForm) { if let Some(name_expr) = target.as_name_expr() { if let Some(known_instance) = KnownInstanceType::try_from_file_and_name( self.db(), @@ -4018,9 +4034,12 @@ impl<'db> TypeInferenceBuilder<'db> { let class = match callable_type { Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() { ClassBase::Dynamic(_) => None, - ClassBase::Class(class) => Some(class), + ClassBase::Class(class) => { + let (class_literal, _) = class.class_literal(self.db()); + Some(class_literal) + } }, - Type::ClassLiteral(ClassLiteralType { class }) => Some(class), + Type::ClassLiteral(class) => Some(class), _ => None, }; @@ -4475,7 +4494,7 @@ impl<'db> TypeInferenceBuilder<'db> { LookupError::Unbound(_) => { let bound_on_instance = match value_type { Type::ClassLiteral(class) => { - !class.class().instance_member(db, attr).symbol.is_unbound() + !class.instance_member(db, None, attr).symbol.is_unbound() } Type::SubclassOf(subclass_of @ SubclassOfType { .. }) => { match subclass_of.subclass_of() { @@ -4588,6 +4607,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::BoundMethod(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) + | Type::GenericAlias(_) | Type::SubclassOf(_) | Type::Instance(_) | Type::KnownInstance(_) @@ -4864,6 +4884,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::MethodWrapper(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) + | Type::GenericAlias(_) | Type::SubclassOf(_) | Type::Instance(_) | Type::KnownInstance(_) @@ -4885,6 +4906,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::MethodWrapper(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) + | Type::GenericAlias(_) | Type::SubclassOf(_) | Type::Instance(_) | Type::KnownInstance(_) @@ -5452,9 +5474,7 @@ impl<'db> TypeInferenceBuilder<'db> { range, ), (Type::Tuple(_), Type::Instance(instance)) - if instance - .class() - .is_known(self.db(), KnownClass::VersionInfo) => + if instance.class.is_known(self.db(), KnownClass::VersionInfo) => { self.infer_binary_type_comparison( left, @@ -5464,9 +5484,7 @@ impl<'db> TypeInferenceBuilder<'db> { ) } (Type::Instance(instance), Type::Tuple(_)) - if instance - .class() - .is_known(self.db(), KnownClass::VersionInfo) => + if instance.class.is_known(self.db(), KnownClass::VersionInfo) => { self.infer_binary_type_comparison( Type::version_info_tuple(self.db()), @@ -5774,11 +5792,71 @@ impl<'db> TypeInferenceBuilder<'db> { ctx: _, } = subscript; + // HACK ALERT: If we are subscripting a generic class, short-circuit the rest of the + // subscript inference logic and treat this as an explicit specialization. + // TODO: Move this logic into a custom callable, and update `find_name_in_mro` to return + // this callable as the `__class_getitem__` method on `type`. That probably requires + // updating all of the subscript logic below to use custom callables for all of the _other_ + // special cases, too. let value_ty = self.infer_expression(value); + if let Type::ClassLiteral(ClassLiteralType::Generic(generic_class)) = value_ty { + return self.infer_explicit_class_specialization( + subscript, + value_ty, + generic_class, + slice, + ); + } + let slice_ty = self.infer_expression(slice); self.infer_subscript_expression_types(value, value_ty, slice_ty) } + fn infer_explicit_class_specialization( + &mut self, + subscript: &ast::ExprSubscript, + value_ty: Type<'db>, + generic_class: GenericClass<'db>, + slice_node: &ast::Expr, + ) -> Type<'db> { + let mut call_argument_types = match slice_node { + ast::Expr::Tuple(tuple) => CallArgumentTypes::positional( + tuple.elts.iter().map(|elt| self.infer_type_expression(elt)), + ), + _ => CallArgumentTypes::positional([self.infer_type_expression(slice_node)]), + }; + let generic_context = generic_class.generic_context(self.db()); + let signatures = Signatures::single(CallableSignature::single( + value_ty, + generic_context.signature(self.db()), + )); + let bindings = match Bindings::match_parameters(signatures, &mut call_argument_types) + .check_types(self.db(), &mut call_argument_types) + { + Ok(bindings) => bindings, + Err(CallError(_, bindings)) => { + bindings.report_diagnostics(&self.context, subscript.into()); + return Type::unknown(); + } + }; + let callable = bindings + .into_iter() + .next() + .expect("valid bindings should have one callable"); + let (_, overload) = callable + .matching_overload() + .expect("valid bindings should have matching overload"); + let specialization = generic_context.specialize( + self.db(), + overload + .parameter_types() + .iter() + .map(|ty| ty.unwrap_or(Type::unknown())) + .collect(), + ); + Type::from(GenericAlias::new(self.db(), generic_class, specialization)) + } + fn infer_subscript_expression_types( &mut self, value_node: &ast::Expr, @@ -5789,16 +5867,12 @@ impl<'db> TypeInferenceBuilder<'db> { ( Type::Instance(instance), Type::IntLiteral(_) | Type::BooleanLiteral(_) | Type::SliceLiteral(_), - ) if instance - .class() - .is_known(self.db(), KnownClass::VersionInfo) => - { - self.infer_subscript_expression_types( + ) if instance.class.is_known(self.db(), KnownClass::VersionInfo) => self + .infer_subscript_expression_types( value_node, Type::version_info_tuple(self.db()), slice_ty, - ) - } + ), // Ex) Given `("a", "b", "c", "d")[1]`, return `"b"` (Type::Tuple(tuple_ty), Type::IntLiteral(int)) if i32::try_from(int).is_ok() => { @@ -6006,9 +6080,16 @@ impl<'db> TypeInferenceBuilder<'db> { } } - if matches!(value_ty, Type::ClassLiteral(class_literal) if class_literal.class().is_known(self.db(), KnownClass::Type)) - { - return KnownClass::GenericAlias.to_instance(self.db()); + if let Type::ClassLiteral(class) = value_ty { + if class.is_known(self.db(), KnownClass::Type) { + return KnownClass::GenericAlias.to_instance(self.db()); + } + + if let ClassLiteralType::Generic(_) = class { + // TODO: specialize the generic class using these explicit type + // variable assignments + return value_ty; + } } report_non_subscriptable( @@ -6031,6 +6112,10 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: proper support for generic classes // For now, just infer `Sequence`, if we see something like `Sequence[str]`. This allows us // to look up attributes on generic base classes, even if we don't understand generics yet. + // Note that this isn't handled by the clause up above for generic classes + // that use legacy type variables and an explicit `Generic` base class. + // Once we handle legacy typevars, this special case will be removed in + // favor of the specialization logic above. value_ty } _ => Type::unknown(), @@ -6063,7 +6148,7 @@ impl<'db> TypeInferenceBuilder<'db> { }, Some(Type::BooleanLiteral(b)) => SliceArg::Arg(Some(i32::from(b))), Some(Type::Instance(instance)) - if instance.class().is_known(self.db(), KnownClass::NoneType) => + if instance.class.is_known(self.db(), KnownClass::NoneType) => { SliceArg::Arg(None) } @@ -6629,8 +6714,7 @@ impl<'db> TypeInferenceBuilder<'db> { value_ty: Type<'db>, ) -> Type<'db> { match value_ty { - Type::ClassLiteral(class_literal_ty) => match class_literal_ty.class().known(self.db()) - { + Type::ClassLiteral(class_literal) => match class_literal.known(self.db()) { Some(KnownClass::Tuple) => self.infer_tuple_type_expression(slice), Some(KnownClass::Type) => self.infer_subclass_of_type_expression(slice), _ => self.infer_subscript_type_expression(subscript, value_ty), @@ -6732,14 +6816,14 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::Name(_) | ast::Expr::Attribute(_) => { let name_ty = self.infer_expression(slice); match name_ty { - Type::ClassLiteral(class_literal_ty) => { - if class_literal_ty - .class() - .is_known(self.db(), KnownClass::Any) - { + Type::ClassLiteral(class_literal) => { + if class_literal.is_known(self.db(), KnownClass::Any) { SubclassOfType::subclass_of_any() } else { - SubclassOfType::from(self.db(), class_literal_ty.class()) + SubclassOfType::from( + self.db(), + class_literal.default_specialization(self.db()), + ) } } Type::KnownInstance(KnownInstanceType::Any) => { @@ -6820,7 +6904,7 @@ impl<'db> TypeInferenceBuilder<'db> { } = subscript; match value_ty { - Type::ClassLiteral(literal) if literal.class().is_known(self.db(), KnownClass::Any) => { + Type::ClassLiteral(literal) if literal.is_known(self.db(), KnownClass::Any) => { self.context.report_lint( &INVALID_TYPE_FORM, subscript, @@ -7516,6 +7600,21 @@ impl StringPartsCollector { } } +fn contains_string_literal(expr: &ast::Expr) -> bool { + struct ContainsStringLiteral(bool); + + impl<'a> Visitor<'a> for ContainsStringLiteral { + fn visit_expr(&mut self, expr: &'a ast::Expr) { + self.0 |= matches!(expr, ast::Expr::StringLiteral(_)); + walk_expr(self, expr); + } + } + + let mut visitor = ContainsStringLiteral(false); + visitor.visit_expr(expr); + visitor.0 +} + #[cfg(test)] mod tests { use crate::db::tests::{setup_db, TestDb}; diff --git a/crates/red_knot_python_semantic/src/types/mro.rs b/crates/red_knot_python_semantic/src/types/mro.rs index b87361e9db..f1d945cae8 100644 --- a/crates/red_knot_python_semantic/src/types/mro.rs +++ b/crates/red_knot_python_semantic/src/types/mro.rs @@ -4,34 +4,58 @@ use std::ops::Deref; use rustc_hash::FxHashSet; use crate::types::class_base::ClassBase; -use crate::types::{Class, Type}; +use crate::types::generics::Specialization; +use crate::types::{ClassLiteralType, ClassType, Type}; use crate::Db; /// The inferred method resolution order of a given class. /// -/// See [`Class::iter_mro`] for more details. +/// An MRO cannot contain non-specialized generic classes. (This is why [`ClassBase`] contains a +/// [`ClassType`], not a [`ClassLiteralType`].) Any generic classes in a base class list are always +/// specialized — either because the class is explicitly specialized if there is a subscript +/// expression, or because we create the default specialization if there isn't. +/// +/// The MRO of a non-specialized generic class can contain generic classes that are specialized +/// with a typevar from the inheriting class. When the inheriting class is specialized, the MRO of +/// the resulting generic alias will substitute those type variables accordingly. For instance, in +/// the following example, the MRO of `D[int]` includes `C[int]`, and the MRO of `D[U]` includes +/// `C[U]` (which is a generic alias, not a non-specialized generic class): +/// +/// ```py +/// class C[T]: ... +/// class D[U](C[U]): ... +/// ``` +/// +/// See [`ClassType::iter_mro`] for more details. #[derive(PartialEq, Eq, Clone, Debug, salsa::Update)] pub(super) struct Mro<'db>(Box<[ClassBase<'db>]>); impl<'db> Mro<'db> { - /// Attempt to resolve the MRO of a given class + /// Attempt to resolve the MRO of a given class. Because we derive the MRO from the list of + /// base classes in the class definition, this operation is performed on a [class + /// literal][ClassLiteralType], not a [class type][ClassType]. (You can _also_ get the MRO of a + /// class type, but this is done by first getting the MRO of the underlying class literal, and + /// specializing each base class as needed if the class type is a generic alias.) /// - /// In the event that a possible list of bases would (or could) lead to a - /// `TypeError` being raised at runtime due to an unresolvable MRO, we infer - /// the MRO of the class as being `[, Unknown, object]`. - /// This seems most likely to reduce the possibility of cascading errors - /// elsewhere. + /// In the event that a possible list of bases would (or could) lead to a `TypeError` being + /// raised at runtime due to an unresolvable MRO, we infer the MRO of the class as being `[, Unknown, object]`. This seems most likely to reduce the possibility of + /// cascading errors elsewhere. (For a generic class, the first entry in this fallback MRO uses + /// the default specialization of the class's type variables.) /// /// (We emit a diagnostic warning about the runtime `TypeError` in /// [`super::infer::TypeInferenceBuilder::infer_region_scope`].) - pub(super) fn of_class(db: &'db dyn Db, class: Class<'db>) -> Result> { - Self::of_class_impl(db, class).map_err(|error_kind| MroError { - kind: error_kind, - fallback_mro: Self::from_error(db, class), + pub(super) fn of_class( + db: &'db dyn Db, + class: ClassLiteralType<'db>, + specialization: Option>, + ) -> Result> { + Self::of_class_impl(db, class, specialization).map_err(|err| { + err.into_mro_error(db, class.apply_optional_specialization(db, specialization)) }) } - pub(super) fn from_error(db: &'db dyn Db, class: Class<'db>) -> Self { + pub(super) fn from_error(db: &'db dyn Db, class: ClassType<'db>) -> Self { Self::from([ ClassBase::Class(class), ClassBase::unknown(), @@ -39,20 +63,30 @@ impl<'db> Mro<'db> { ]) } - fn of_class_impl(db: &'db dyn Db, class: Class<'db>) -> Result> { + fn of_class_impl( + db: &'db dyn Db, + class: ClassLiteralType<'db>, + specialization: Option>, + ) -> Result> { let class_bases = class.explicit_bases(db); if !class_bases.is_empty() && class.inheritance_cycle(db).is_some() { // We emit errors for cyclically defined classes elsewhere. // It's important that we don't even try to infer the MRO for a cyclically defined class, // or we'll end up in an infinite loop. - return Ok(Mro::from_error(db, class)); + return Ok(Mro::from_error( + db, + class.apply_optional_specialization(db, specialization), + )); } match class_bases { // `builtins.object` is the special case: // the only class in Python that has an MRO with length <2 - [] if class.is_object(db) => Ok(Self::from([ClassBase::Class(class)])), + [] if class.is_object(db) => Ok(Self::from([ + // object is not generic, so the default specialization should be a no-op + ClassBase::Class(class.apply_optional_specialization(db, specialization)), + ])), // All other classes in Python have an MRO with length >=2. // Even if a class has no explicit base classes, @@ -67,7 +101,10 @@ impl<'db> Mro<'db> { // >>> Foo.__mro__ // (, ) // ``` - [] => Ok(Self::from([ClassBase::Class(class), ClassBase::object(db)])), + [] => Ok(Self::from([ + ClassBase::Class(class.apply_optional_specialization(db, specialization)), + ClassBase::object(db), + ])), // Fast path for a class that has only a single explicit base. // @@ -77,9 +114,11 @@ impl<'db> Mro<'db> { [single_base] => ClassBase::try_from_type(db, *single_base).map_or_else( || Err(MroErrorKind::InvalidBases(Box::from([(0, *single_base)]))), |single_base| { - Ok(std::iter::once(ClassBase::Class(class)) - .chain(single_base.mro(db)) - .collect()) + Ok(std::iter::once(ClassBase::Class( + class.apply_optional_specialization(db, specialization), + )) + .chain(single_base.mro(db)) + .collect()) }, ), @@ -103,7 +142,9 @@ impl<'db> Mro<'db> { return Err(MroErrorKind::InvalidBases(invalid_bases.into_boxed_slice())); } - let mut seqs = vec![VecDeque::from([ClassBase::Class(class)])]; + let mut seqs = vec![VecDeque::from([ClassBase::Class( + class.apply_optional_specialization(db, specialization), + )])]; for base in &valid_bases { seqs.push(base.mro(db).collect()); } @@ -118,7 +159,8 @@ impl<'db> Mro<'db> { .filter_map(|(index, base)| Some((index, base.into_class()?))) { if !seen_bases.insert(base) { - duplicate_bases.push((index, base)); + let (base_class_literal, _) = base.class_literal(db); + duplicate_bases.push((index, base_class_literal)); } } @@ -178,12 +220,15 @@ impl<'db> FromIterator> for Mro<'db> { /// /// Even for first-party code, where we will have to resolve the MRO for every class we encounter, /// loading the cached MRO comes with a certain amount of overhead, so it's best to avoid calling the -/// Salsa-tracked [`Class::try_mro`] method unless it's absolutely necessary. +/// Salsa-tracked [`ClassLiteralType::try_mro`] method unless it's absolutely necessary. pub(super) struct MroIterator<'db> { db: &'db dyn Db, /// The class whose MRO we're iterating over - class: Class<'db>, + class: ClassLiteralType<'db>, + + /// The specialization to apply to each MRO element, if any + specialization: Option>, /// Whether or not we've already yielded the first element of the MRO first_element_yielded: bool, @@ -197,10 +242,15 @@ pub(super) struct MroIterator<'db> { } impl<'db> MroIterator<'db> { - pub(super) fn new(db: &'db dyn Db, class: Class<'db>) -> Self { + pub(super) fn new( + db: &'db dyn Db, + class: ClassLiteralType<'db>, + specialization: Option>, + ) -> Self { Self { db, class, + specialization, first_element_yielded: false, subsequent_elements: None, } @@ -211,7 +261,7 @@ impl<'db> MroIterator<'db> { fn full_mro_except_first_element(&mut self) -> impl Iterator> + '_ { self.subsequent_elements .get_or_insert_with(|| { - let mut full_mro_iter = match self.class.try_mro(self.db) { + let mut full_mro_iter = match self.class.try_mro(self.db, self.specialization) { Ok(mro) => mro.iter(), Err(error) => error.fallback_mro().iter(), }; @@ -228,7 +278,10 @@ impl<'db> Iterator for MroIterator<'db> { fn next(&mut self) -> Option { if !self.first_element_yielded { self.first_element_yielded = true; - return Some(ClassBase::Class(self.class)); + return Some(ClassBase::Class( + self.class + .apply_optional_specialization(self.db, self.specialization), + )); } self.full_mro_except_first_element().next() } @@ -273,11 +326,11 @@ pub(super) enum MroErrorKind<'db> { /// The class has one or more duplicate bases. /// - /// This variant records the indices and [`Class`]es + /// This variant records the indices and [`ClassLiteralType`]s /// of the duplicate bases. The indices are the indices of nodes /// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node. /// Each index is the index of a node representing a duplicate base. - DuplicateBases(Box<[(usize, Class<'db>)]>), + DuplicateBases(Box<[(usize, ClassLiteralType<'db>)]>), /// The MRO is otherwise unresolvable through the C3-merge algorithm. /// @@ -285,6 +338,15 @@ pub(super) enum MroErrorKind<'db> { UnresolvableMro { bases_list: Box<[ClassBase<'db>]> }, } +impl<'db> MroErrorKind<'db> { + pub(super) fn into_mro_error(self, db: &'db dyn Db, class: ClassType<'db>) -> MroError<'db> { + MroError { + kind: self, + fallback_mro: Mro::from_error(db, class), + } + } +} + /// Implementation of the [C3-merge algorithm] for calculating a Python class's /// [method resolution order]. /// diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 4ba8e4845f..cf5431b47e 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -159,10 +159,10 @@ impl KnownConstraintFunction { Type::ClassLiteral(class_literal) => { // At runtime (on Python 3.11+), this will return `True` for classes that actually // do inherit `typing.Any` and `False` otherwise. We could accurately model that? - if class_literal.class().is_known(db, KnownClass::Any) { + if class_literal.is_known(db, KnownClass::Any) { None } else { - Some(constraint_fn(class_literal.class())) + Some(constraint_fn(class_literal.default_specialization(db))) } } Type::SubclassOf(subclass_of_ty) => { @@ -473,8 +473,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { range: _, }, }) if keywords.is_empty() => { - let Type::ClassLiteral(ClassLiteralType { class: rhs_class }) = rhs_ty else { - continue; + let rhs_class = match rhs_ty { + Type::ClassLiteral(class) => class, + Type::GenericAlias(alias) => { + ClassLiteralType::Generic(alias.origin(self.db)) + } + _ => { + continue; + } }; let [ast::Expr::Name(ast::ExprName { id, .. })] = &**args else { @@ -496,10 +502,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> { if callable_type .into_class_literal() - .is_some_and(|c| c.class().is_known(self.db, KnownClass::Type)) + .is_some_and(|c| c.is_known(self.db, KnownClass::Type)) { let symbol = self.expect_expr_name_symbol(id); - constraints.insert(symbol, Type::instance(rhs_class)); + constraints.insert( + symbol, + Type::instance(rhs_class.unknown_specialization(self.db)), + ); } } _ => {} @@ -550,7 +559,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> { Type::ClassLiteral(class_type) if expr_call.arguments.args.len() == 1 && expr_call.arguments.keywords.is_empty() - && class_type.class().is_known(self.db, KnownClass::Bool) => + && class_type.is_known(self.db, KnownClass::Bool) => { self.evaluate_expression_node_predicate( &expr_call.arguments.args[0], diff --git a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs index 4dcc6b6c08..8072190159 100644 --- a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs +++ b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs @@ -11,6 +11,8 @@ use ruff_python_ast::name::Name; /// A test representation of a type that can be transformed unambiguously into a real Type, /// given a db. +/// +/// TODO: We should add some variants that exercise generic classes and specializations thereof. #[derive(Debug, Clone, PartialEq)] pub(crate) enum Ty { Never, @@ -167,7 +169,7 @@ impl Ty { .symbol .expect_type() .expect_class_literal() - .class, + .default_specialization(db), ), Ty::SubclassOfAbcClass(s) => SubclassOfType::from( db, @@ -175,7 +177,7 @@ impl Ty { .symbol .expect_type() .expect_class_literal() - .class, + .default_specialization(db), ), Ty::AlwaysTruthy => Type::AlwaysTruthy, Ty::AlwaysFalsy => Type::AlwaysFalsy, diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index f67fa7c79c..6a1091398f 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -14,6 +14,7 @@ use smallvec::{smallvec, SmallVec}; use super::{definition_expression_type, DynamicType, Type}; use crate::semantic_index::definition::Definition; +use crate::types::generics::Specialization; use crate::types::todo_type; use crate::Db; use ruff_python_ast::{self as ast, name::Name}; @@ -261,6 +262,17 @@ impl<'db> Signature<'db> { } } + pub(crate) fn apply_specialization( + &mut self, + db: &'db dyn Db, + specialization: Specialization<'db>, + ) { + self.parameters.apply_specialization(db, specialization); + self.return_ty = self + .return_ty + .map(|ty| ty.apply_specialization(db, specialization)); + } + /// Return the parameters in this signature. pub(crate) fn parameters(&self) -> &Parameters<'db> { &self.parameters @@ -445,6 +457,12 @@ impl<'db> Parameters<'db> { ) } + fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { + self.value + .iter_mut() + .for_each(|param| param.apply_specialization(db, specialization)); + } + pub(crate) fn len(&self) -> usize { self.value.len() } @@ -606,6 +624,13 @@ impl<'db> Parameter<'db> { self } + fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { + self.annotated_type = self + .annotated_type + .map(|ty| ty.apply_specialization(db, specialization)); + self.kind.apply_specialization(db, specialization); + } + /// Strip information from the parameter so that two equivalent parameters compare equal. /// Normalize nested unions and intersections in the annotated type, if any. /// @@ -792,6 +817,19 @@ pub(crate) enum ParameterKind<'db> { }, } +impl<'db> ParameterKind<'db> { + fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { + match self { + Self::PositionalOnly { default_type, .. } + | Self::PositionalOrKeyword { default_type, .. } + | Self::KeywordOnly { default_type, .. } => { + *default_type = default_type.map(|ty| ty.apply_specialization(db, specialization)); + } + Self::Variadic { .. } | Self::KeywordVariadic { .. } => {} + } + } +} + /// Whether a parameter is used as a value or a type form. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub(crate) enum ParameterForm { diff --git a/crates/red_knot_python_semantic/src/types/slots.rs b/crates/red_knot_python_semantic/src/types/slots.rs index b0afda81ba..2b265372ad 100644 --- a/crates/red_knot_python_semantic/src/types/slots.rs +++ b/crates/red_knot_python_semantic/src/types/slots.rs @@ -4,7 +4,7 @@ use crate::db::Db; use crate::symbol::{Boundness, Symbol}; use crate::types::class_base::ClassBase; use crate::types::diagnostic::report_base_with_incompatible_slots; -use crate::types::{Class, ClassLiteralType, Type}; +use crate::types::{ClassLiteralType, Type}; use super::InferContext; @@ -23,7 +23,7 @@ enum SlotsKind { } impl SlotsKind { - fn from(db: &dyn Db, base: Class) -> Self { + fn from(db: &dyn Db, base: ClassLiteralType) -> Self { let Symbol::Type(slots_ty, bound) = base.own_class_member(db, "__slots__").symbol else { return Self::NotSpecified; }; @@ -50,7 +50,11 @@ impl SlotsKind { } } -pub(super) fn check_class_slots(context: &InferContext, class: Class, node: &ast::StmtClassDef) { +pub(super) fn check_class_slots( + context: &InferContext, + class: ClassLiteralType, + node: &ast::StmtClassDef, +) { let db = context.db(); let mut first_with_solid_base = None; @@ -58,16 +62,17 @@ pub(super) fn check_class_slots(context: &InferContext, class: Class, node: &ast let mut found_second = false; for (index, base) in class.explicit_bases(db).iter().enumerate() { - let Type::ClassLiteral(ClassLiteralType { class: base }) = base else { + let Type::ClassLiteral(base) = base else { continue; }; - let solid_base = base.iter_mro(db).find_map(|current| { + let solid_base = base.iter_mro(db, None).find_map(|current| { let ClassBase::Class(current) = current else { return None; }; - match SlotsKind::from(db, current) { + let (class_literal, _) = current.class_literal(db); + match SlotsKind::from(db, class_literal) { SlotsKind::NotEmpty => Some(current), SlotsKind::NotSpecified | SlotsKind::Empty => None, SlotsKind::Dynamic => None, diff --git a/crates/red_knot_python_semantic/src/types/subclass_of.rs b/crates/red_knot_python_semantic/src/types/subclass_of.rs index b69c7505c8..49c1d168fc 100644 --- a/crates/red_knot_python_semantic/src/types/subclass_of.rs +++ b/crates/red_knot_python_semantic/src/types/subclass_of.rs @@ -1,6 +1,6 @@ use crate::symbol::SymbolAndQualifiers; -use super::{ClassBase, ClassLiteralType, Db, KnownClass, MemberLookupPolicy, Type}; +use super::{ClassBase, Db, KnownClass, MemberLookupPolicy, Type}; /// A type that represents `type[C]`, i.e. the class object `C` and class objects that are subclasses of `C`. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)] @@ -27,7 +27,7 @@ impl<'db> SubclassOfType<'db> { ClassBase::Dynamic(_) => Type::SubclassOf(Self { subclass_of }), ClassBase::Class(class) => { if class.is_final(db) { - Type::ClassLiteral(ClassLiteralType { class }) + Type::from(class) } else if class.is_object(db) { KnownClass::Type.to_instance(db) } else { diff --git a/crates/red_knot_python_semantic/src/types/type_ordering.rs b/crates/red_knot_python_semantic/src/types/type_ordering.rs index 71252ebb25..c3f8177266 100644 --- a/crates/red_knot_python_semantic/src/types/type_ordering.rs +++ b/crates/red_knot_python_semantic/src/types/type_ordering.rs @@ -2,10 +2,7 @@ use std::cmp::Ordering; use crate::db::Db; -use super::{ - class_base::ClassBase, ClassLiteralType, DynamicType, InstanceType, KnownInstanceType, - TodoType, Type, -}; +use super::{class_base::ClassBase, DynamicType, InstanceType, KnownInstanceType, TodoType, Type}; /// Return an [`Ordering`] that describes the canonical order in which two types should appear /// in an [`crate::types::IntersectionType`] or a [`crate::types::UnionType`] in order for them @@ -93,13 +90,14 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::ModuleLiteral(_), _) => Ordering::Less, (_, Type::ModuleLiteral(_)) => Ordering::Greater, - ( - Type::ClassLiteral(ClassLiteralType { class: left }), - Type::ClassLiteral(ClassLiteralType { class: right }), - ) => left.cmp(right), + (Type::ClassLiteral(left), Type::ClassLiteral(right)) => left.cmp(right), (Type::ClassLiteral(_), _) => Ordering::Less, (_, Type::ClassLiteral(_)) => Ordering::Greater, + (Type::GenericAlias(left), Type::GenericAlias(right)) => left.cmp(right), + (Type::GenericAlias(_), _) => Ordering::Less, + (_, Type::GenericAlias(_)) => Ordering::Greater, + (Type::SubclassOf(left), Type::SubclassOf(right)) => { match (left.subclass_of(), right.subclass_of()) { (ClassBase::Class(left), ClassBase::Class(right)) => left.cmp(&right),