From de1f8177be9124a84f54219b3b5de300d0cafe67 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Sun, 29 Jun 2025 19:46:33 +0900 Subject: [PATCH] [ty] Improve protocol member type checking and relation handling (#18847) Co-authored-by: Alex Waygood --- .../resources/mdtest/annotations/callable.md | 15 +- .../mdtest/generics/legacy/classes.md | 4 +- .../resources/mdtest/narrow/hasattr.md | 7 + .../resources/mdtest/protocols.md | 170 ++++++++++++++- .../type_properties/is_disjoint_from.md | 61 ++++++ crates/ty_python_semantic/src/types.rs | 155 +++++++------ .../ty_python_semantic/src/types/instance.rs | 29 +-- crates/ty_python_semantic/src/types/narrow.rs | 6 +- .../src/types/protocol_class.rs | 204 +++++++++++++++--- 9 files changed, 519 insertions(+), 132 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md index 133a055ee9..c9bec98216 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md @@ -371,14 +371,23 @@ class MyCallable: f_wrong(MyCallable()) # raises `AttributeError` at runtime ``` -If users want to write to attributes such as `__qualname__`, they need to check the existence of the -attribute first: +If users want to read/write to attributes such as `__qualname__`, they need to check the existence +of the attribute first: ```py +from inspect import getattr_static + def f_okay(c: Callable[[], None]): if hasattr(c, "__qualname__"): c.__qualname__ # okay - c.__qualname__ = "my_callable" # also okay + # `hasattr` only guarantees that an attribute is readable. + # error: [invalid-assignment] "Object of type `Literal["my_callable"]` is not assignable to attribute `__qualname__` on type `(() -> None) & `" + c.__qualname__ = "my_callable" + + result = getattr_static(c, "__qualname__") + reveal_type(result) # revealed: Never + if isinstance(result, property) and result.fset: + c.__qualname__ = "my_callable" # okay ``` [gradual form]: https://typing.python.org/en/latest/spec/glossary.html#term-gradual-form diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md index 5e8858bf7b..a15bc32da7 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md @@ -482,8 +482,8 @@ reveal_type(c.method3()) # revealed: LinkedList[int] class SomeProtocol(Protocol[T]): x: T -class Foo: - x: int +class Foo(Generic[T]): + x: T class D(Generic[T, U]): x: T diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/hasattr.md b/crates/ty_python_semantic/resources/mdtest/narrow/hasattr.md index 394eb97588..c8148f54f6 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/hasattr.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/hasattr.md @@ -5,6 +5,7 @@ accomplished using an intersection with a synthesized protocol: ```py from typing import final +from typing_extensions import LiteralString class Foo: ... @@ -56,4 +57,10 @@ def h(obj: Baz): # TODO: should emit `[unresolved-attribute]` and reveal `Unknown` reveal_type(obj.x) # revealed: @Todo(map_with_boundness: intersections with negative contributions) + +def i(x: int | LiteralString): + if hasattr(x, "capitalize"): + reveal_type(x) # revealed: (int & ) | LiteralString + else: + reveal_type(x) # revealed: int & ~ ``` diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index 62aa534ae1..c1e052f582 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -489,35 +489,122 @@ python-version = "3.12" ``` ```py -from typing import Protocol +from typing import Protocol, Any, ClassVar +from collections.abc import Sequence from ty_extensions import static_assert, is_assignable_to, is_subtype_of class HasX(Protocol): x: int +class HasXY(Protocol): + x: int + y: int + class Foo: x: int static_assert(is_subtype_of(Foo, HasX)) static_assert(is_assignable_to(Foo, HasX)) +static_assert(not is_subtype_of(Foo, HasXY)) +static_assert(not is_assignable_to(Foo, HasXY)) class FooSub(Foo): ... static_assert(is_subtype_of(FooSub, HasX)) static_assert(is_assignable_to(FooSub, HasX)) +static_assert(not is_subtype_of(FooSub, HasXY)) +static_assert(not is_assignable_to(FooSub, HasXY)) + +class FooBool(Foo): + x: bool + +static_assert(not is_subtype_of(FooBool, HasX)) +static_assert(not is_assignable_to(FooBool, HasX)) + +class FooAny: + x: Any + +static_assert(not is_subtype_of(FooAny, HasX)) +static_assert(is_assignable_to(FooAny, HasX)) + +class SubclassOfAny(Any): ... + +class FooSubclassOfAny: + x: SubclassOfAny + +static_assert(not is_subtype_of(FooSubclassOfAny, HasX)) +static_assert(not is_assignable_to(FooSubclassOfAny, HasX)) + +class FooWithY(Foo): + y: int + +assert is_subtype_of(FooWithY, HasXY) +static_assert(is_assignable_to(FooWithY, HasXY)) class Bar: x: str -# TODO: these should pass -static_assert(not is_subtype_of(Bar, HasX)) # error: [static-assert-error] -static_assert(not is_assignable_to(Bar, HasX)) # error: [static-assert-error] +static_assert(not is_subtype_of(Bar, HasX)) +static_assert(not is_assignable_to(Bar, HasX)) class Baz: y: int static_assert(not is_subtype_of(Baz, HasX)) static_assert(not is_assignable_to(Baz, HasX)) + +class Qux: + def __init__(self, x: int) -> None: + self.x: int = x + +static_assert(is_subtype_of(Qux, HasX)) +static_assert(is_assignable_to(Qux, HasX)) + +class HalfUnknownQux: + def __init__(self, x: int) -> None: + self.x = x + +reveal_type(HalfUnknownQux(1).x) # revealed: Unknown | int + +static_assert(not is_subtype_of(HalfUnknownQux, HasX)) +static_assert(is_assignable_to(HalfUnknownQux, HasX)) + +class FullyUnknownQux: + def __init__(self, x) -> None: + self.x = x + +static_assert(not is_subtype_of(FullyUnknownQux, HasX)) +static_assert(is_assignable_to(FullyUnknownQux, HasX)) + +class HasXWithDefault(Protocol): + x: int = 0 + +class FooWithZero: + x: int = 0 + +# TODO: these should pass +static_assert(is_subtype_of(FooWithZero, HasXWithDefault)) # error: [static-assert-error] +static_assert(is_assignable_to(FooWithZero, HasXWithDefault)) # error: [static-assert-error] +static_assert(not is_subtype_of(Foo, HasXWithDefault)) +static_assert(not is_assignable_to(Foo, HasXWithDefault)) +static_assert(not is_subtype_of(Qux, HasXWithDefault)) +static_assert(not is_assignable_to(Qux, HasXWithDefault)) + +class HasClassVarX(Protocol): + x: ClassVar[int] + +static_assert(is_subtype_of(FooWithZero, HasClassVarX)) +static_assert(is_assignable_to(FooWithZero, HasClassVarX)) +# TODO: these should pass +static_assert(not is_subtype_of(Foo, HasClassVarX)) # error: [static-assert-error] +static_assert(not is_assignable_to(Foo, HasClassVarX)) # error: [static-assert-error] +static_assert(not is_subtype_of(Qux, HasClassVarX)) # error: [static-assert-error] +static_assert(not is_assignable_to(Qux, HasClassVarX)) # error: [static-assert-error] + +static_assert(is_subtype_of(Sequence[Foo], Sequence[HasX])) +static_assert(is_assignable_to(Sequence[Foo], Sequence[HasX])) +static_assert(not is_subtype_of(list[Foo], list[HasX])) +static_assert(not is_assignable_to(list[Foo], list[HasX])) ``` Note that declaring an attribute member on a protocol mandates that the attribute must be mutable. A @@ -552,10 +639,8 @@ class C: # due to invariance, a type is only a subtype of `HasX` # if its `x` attribute is of type *exactly* `int`: # a subclass of `int` does not satisfy the interface -# -# TODO: these should pass -static_assert(not is_subtype_of(C, HasX)) # error: [static-assert-error] -static_assert(not is_assignable_to(C, HasX)) # error: [static-assert-error] +static_assert(not is_subtype_of(C, HasX)) +static_assert(not is_assignable_to(C, HasX)) ``` All attributes on frozen dataclasses and namedtuples are immutable, so instances of these classes @@ -1229,6 +1314,62 @@ static_assert(is_subtype_of(HasGetAttrAndSetAttr, XAsymmetricProperty)) # error static_assert(is_assignable_to(HasGetAttrAndSetAttr, XAsymmetricProperty)) # error: [static-assert-error] ``` +## Subtyping of protocols with method members + +A protocol can have method members. `T` is assignable to `P` in the following example because the +class `T` has a method `m` which is assignable to the `Callable` supertype of the method `P.m`: + +```py +from typing import Protocol +from ty_extensions import is_subtype_of, static_assert + +class P(Protocol): + def m(self, x: int, /) -> None: ... + +class NominalSubtype: + def m(self, y: int) -> None: ... + +class NotSubtype: + def m(self, x: int) -> int: + return 42 + +static_assert(is_subtype_of(NominalSubtype, P)) + +# TODO: should pass +static_assert(not is_subtype_of(NotSubtype, P)) # error: [static-assert-error] +``` + +## Equivalence of protocols with method members + +Two protocols `P1` and `P2`, both with a method member `x`, are considered equivalent if the +signature of `P1.x` is equivalent to the signature of `P2.x`, even though ty would normally model +any two function definitions as inhabiting distinct function-literal types. + +```py +from typing import Protocol +from ty_extensions import is_equivalent_to, static_assert + +class P1(Protocol): + def x(self, y: int) -> None: ... + +class P2(Protocol): + def x(self, y: int) -> None: ... + +# TODO: this should pass +static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error] +``` + +As with protocols that only have non-method members, this also holds true when they appear in +differently ordered unions: + +```py +class A: ... +class B: ... + +# TODO: this should pass +static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error] +``` + ## Narrowing of protocols @@ -1458,7 +1599,7 @@ def two(some_list: list, some_tuple: tuple[int, str], some_sized: Sized): ```py from __future__ import annotations -from typing import Protocol, Any +from typing import Protocol, Any, TypeVar from ty_extensions import static_assert, is_assignable_to, is_subtype_of, is_equivalent_to class RecursiveFullyStatic(Protocol): @@ -1514,6 +1655,17 @@ class Bar(Protocol): # TODO: this should pass # error: [static-assert-error] static_assert(is_equivalent_to(Foo, Bar)) + +T = TypeVar("T", bound="TypeVarRecursive") + +class TypeVarRecursive(Protocol): + # TODO: commenting this out will cause a stack overflow. + # x: T + y: "TypeVarRecursive" + +def _(t: TypeVarRecursive): + # reveal_type(t.x) # revealed: T + reveal_type(t.y) # revealed: TypeVarRecursive ``` ### Nested occurrences of self-reference diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md index b3ae6150b5..9441cc1334 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md @@ -501,6 +501,67 @@ static_assert(is_disjoint_from(str, TypeGuard[str])) # error: [static-assert-er static_assert(is_disjoint_from(str, TypeIs[str])) ``` +### `Protocol` + +A protocol is disjoint from another type if any of the protocol's members are available as an +attribute on the other type *but* the type of the attribute on the other type is disjoint from the +type of the protocol's member. + +```py +from typing_extensions import Protocol, Literal, final, ClassVar +from ty_extensions import is_disjoint_from, static_assert + +class HasAttrA(Protocol): + attr: Literal["a"] + +class SupportsInt(Protocol): + def __int__(self) -> int: ... + +class A: + attr: Literal["a"] + +class B: + attr: Literal["b"] + +class C: + foo: int + +class D: + attr: int + +@final +class E: + pass + +@final +class F: + def __int__(self) -> int: + return 1 + +static_assert(not is_disjoint_from(HasAttrA, A)) +static_assert(is_disjoint_from(HasAttrA, B)) +# A subclass of E may satisfy HasAttrA +static_assert(not is_disjoint_from(HasAttrA, C)) +static_assert(is_disjoint_from(HasAttrA, D)) +static_assert(is_disjoint_from(HasAttrA, E)) + +static_assert(is_disjoint_from(SupportsInt, E)) +static_assert(not is_disjoint_from(SupportsInt, F)) + +class NotIterable(Protocol): + __iter__: ClassVar[None] + +static_assert(is_disjoint_from(tuple[int, int], NotIterable)) + +class Foo: + BAR: ClassVar[int] + +class BarNone(Protocol): + BAR: None + +static_assert(is_disjoint_from(type[Foo], BarNone)) +``` + ## Callables No two callable types are disjoint because there exists a non-empty callable type diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 54b76f562a..c258c10931 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -404,6 +404,22 @@ impl<'db> PropertyInstanceType<'db> { ty.find_legacy_typevars(db, typevars); } } + + fn materialize(self, db: &'db dyn Db, variance: TypeVarVariance) -> Self { + Self::new( + db, + self.getter(db).map(|ty| ty.materialize(db, variance)), + self.setter(db).map(|ty| ty.materialize(db, variance)), + ) + } + + fn any_over_type(self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool { + self.getter(db) + .is_some_and(|ty| ty.any_over_type(db, type_fn)) + || self + .setter(db) + .is_some_and(|ty| ty.any_over_type(db, type_fn)) + } } bitflags! { @@ -681,10 +697,13 @@ impl<'db> Type<'db> { | Type::KnownInstance(_) | Type::AlwaysFalsy | Type::AlwaysTruthy - | Type::PropertyInstance(_) | Type::ClassLiteral(_) | Type::BoundSuper(_) => *self, + Type::PropertyInstance(property_instance) => { + Type::PropertyInstance(property_instance.materialize(db, variance)) + } + Type::FunctionLiteral(_) | Type::BoundMethod(_) => { // TODO: Subtyping between function / methods with a callable accounts for the // signature (parameters and return type), so we might need to do something here @@ -902,15 +921,7 @@ impl<'db> Type<'db> { } Self::ProtocolInstance(protocol) => protocol.any_over_type(db, type_fn), - - Self::PropertyInstance(property) => { - property - .getter(db) - .is_some_and(|ty| ty.any_over_type(db, type_fn)) - || property - .setter(db) - .is_some_and(|ty| ty.any_over_type(db, type_fn)) - } + Self::PropertyInstance(property) => property.any_over_type(db, type_fn), Self::NominalInstance(instance) => match instance.class { ClassType::NonGeneric(_) => false, @@ -1453,7 +1464,9 @@ impl<'db> Type<'db> { } // A protocol instance can never be a subtype of a nominal type, with the *sole* exception of `object`. (Type::ProtocolInstance(_), _) => false, - (_, Type::ProtocolInstance(protocol)) => self.satisfies_protocol(db, protocol), + (_, Type::ProtocolInstance(protocol)) => { + self.satisfies_protocol(db, protocol, relation) + } // All `StringLiteral` types are a subtype of `LiteralString`. (Type::StringLiteral(_), Type::LiteralString) => true, @@ -1865,26 +1878,6 @@ impl<'db> Type<'db> { Type::Tuple(..), ) => true, - (Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b)) - | (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => { - match subclass_of_ty.subclass_of() { - SubclassOfInner::Dynamic(_) => false, - SubclassOfInner::Class(class_a) => !class_b.is_subclass_of(db, None, class_a), - } - } - - (Type::SubclassOf(subclass_of_ty), Type::GenericAlias(alias_b)) - | (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => { - match subclass_of_ty.subclass_of() { - SubclassOfInner::Dynamic(_) => false, - SubclassOfInner::Class(class_a) => { - !ClassType::from(alias_b).is_subclass_of(db, class_a) - } - } - } - - (Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from(db, right), - ( Type::SubclassOf(_), Type::BooleanLiteral(..) @@ -1912,28 +1905,6 @@ impl<'db> Type<'db> { Type::SubclassOf(_), ) => true, - (Type::AlwaysTruthy, ty) | (ty, Type::AlwaysTruthy) => { - // `Truthiness::Ambiguous` may include `AlwaysTrue` as a subset, so it's not guaranteed to be disjoint. - // Thus, they are only disjoint if `ty.bool() == AlwaysFalse`. - ty.bool(db).is_always_false() - } - (Type::AlwaysFalsy, ty) | (ty, Type::AlwaysFalsy) => { - // Similarly, they are only disjoint if `ty.bool() == AlwaysTrue`. - ty.bool(db).is_always_true() - } - - (Type::ProtocolInstance(left), Type::ProtocolInstance(right)) => { - left.is_disjoint_from(db, right) - } - - // TODO: we could also consider `protocol` to be disjoint from `nominal` if `nominal` - // has the right member but the type of its member is disjoint from the type of the - // member on `protocol`. - (Type::ProtocolInstance(protocol), nominal @ Type::NominalInstance(n)) - | (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => { - n.class.is_final(db) && !nominal.satisfies_protocol(db, protocol) - } - ( ty @ (Type::LiteralString | Type::StringLiteral(..) @@ -1957,36 +1928,75 @@ impl<'db> Type<'db> { | Type::ModuleLiteral(..) | Type::GenericAlias(..) | Type::IntLiteral(..)), - ) => !ty.satisfies_protocol(db, protocol), + ) => !ty.satisfies_protocol(db, protocol, TypeRelation::Assignability), + + (Type::AlwaysTruthy, ty) | (ty, Type::AlwaysTruthy) => { + // `Truthiness::Ambiguous` may include `AlwaysTrue` as a subset, so it's not guaranteed to be disjoint. + // Thus, they are only disjoint if `ty.bool() == AlwaysFalse`. + ty.bool(db).is_always_false() + } + (Type::AlwaysFalsy, ty) | (ty, Type::AlwaysFalsy) => { + // Similarly, they are only disjoint if `ty.bool() == AlwaysTrue`. + ty.bool(db).is_always_true() + } + + (Type::ProtocolInstance(left), Type::ProtocolInstance(right)) => { + left.is_disjoint_from(db, right) + } (Type::ProtocolInstance(protocol), Type::SpecialForm(special_form)) | (Type::SpecialForm(special_form), Type::ProtocolInstance(protocol)) => !special_form .instance_fallback(db) - .satisfies_protocol(db, protocol), + .satisfies_protocol(db, protocol, TypeRelation::Assignability), (Type::ProtocolInstance(protocol), Type::KnownInstance(known_instance)) | (Type::KnownInstance(known_instance), Type::ProtocolInstance(protocol)) => { - !known_instance - .instance_fallback(db) - .satisfies_protocol(db, protocol) + !known_instance.instance_fallback(db).satisfies_protocol( + db, + protocol, + TypeRelation::Assignability, + ) } - (Type::Callable(_), Type::ProtocolInstance(_)) - | (Type::ProtocolInstance(_), Type::Callable(_)) => { - // TODO disjointness between `Callable` and `ProtocolInstance` - false + (Type::ProtocolInstance(protocol), nominal @ Type::NominalInstance(n)) + | (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) + if n.class.is_final(db) => + { + !nominal.satisfies_protocol(db, protocol, TypeRelation::Assignability) } - (Type::Tuple(..), Type::ProtocolInstance(..)) - | (Type::ProtocolInstance(..), Type::Tuple(..)) => { - // Currently we do not make any general assumptions about the disjointness of a `Tuple` type - // and a `ProtocolInstance` type because a `Tuple` type can be an instance of a tuple - // subclass. - // - // TODO when we capture the types of the protocol members, we can improve on this. - false + (Type::ProtocolInstance(protocol), other) + | (other, Type::ProtocolInstance(protocol)) => { + protocol.interface(db).members(db).any(|member| { + // TODO: implement disjointness for property/method members as well as attribute members + member.is_attribute_member() + && matches!( + other.member(db, member.name()).place, + Place::Type(ty, Boundness::Bound) if ty.is_disjoint_from(db, member.ty()) + ) + }) } + (Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b)) + | (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => { + match subclass_of_ty.subclass_of() { + SubclassOfInner::Dynamic(_) => false, + SubclassOfInner::Class(class_a) => !class_b.is_subclass_of(db, None, class_a), + } + } + + (Type::SubclassOf(subclass_of_ty), Type::GenericAlias(alias_b)) + | (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => { + match subclass_of_ty.subclass_of() { + SubclassOfInner::Dynamic(_) => false, + SubclassOfInner::Class(class_a) => { + !ClassType::from(alias_b).is_subclass_of(db, class_a) + } + } + } + + (Type::SubclassOf(left), Type::SubclassOf(right)) => left.is_disjoint_from(db, right), + // for `type[Any]`/`type[Unknown]`/`type[Todo]`, we know the type cannot be any larger than `type`, // so although the type is dynamic we can still determine disjointedness in some situations (Type::SubclassOf(subclass_of_ty), other) @@ -2531,6 +2541,11 @@ impl<'db> Type<'db> { Type::Intersection(inter) => inter.map_with_boundness_and_qualifiers(db, |elem| { elem.class_member_with_policy(db, name.clone(), policy) }), + // TODO: Once `to_meta_type` for the synthesized protocol is fully implemented, this handling should be removed. + Type::ProtocolInstance(ProtocolInstanceType { + inner: Protocol::Synthesized(_), + .. + }) => self.instance_member(db, &name), _ => self .to_meta_type(db) .find_name_in_mro_with_policy(db, name.as_str(), policy) diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index bc609339e8..8cf1091ad7 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -4,7 +4,7 @@ use std::marker::PhantomData; use super::protocol_class::ProtocolInterface; use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance}; -use crate::place::{Boundness, Place, PlaceAndQualifiers}; +use crate::place::{Place, PlaceAndQualifiers}; use crate::types::tuple::TupleType; use crate::types::{ClassLiteral, DynamicType, TypeMapping, TypeRelation, TypeVarInstance}; use crate::{Db, FxOrderSet}; @@ -35,31 +35,28 @@ impl<'db> Type<'db> { } } - pub(super) fn synthesized_protocol<'a, M>(db: &'db dyn Db, members: M) -> Self + /// Synthesize a protocol instance type with a given set of read-only property members. + pub(super) fn protocol_with_readonly_members<'a, M>(db: &'db dyn Db, members: M) -> Self where M: IntoIterator)>, { Self::ProtocolInstance(ProtocolInstanceType::synthesized( - SynthesizedProtocolType::new(db, ProtocolInterface::with_members(db, members)), + SynthesizedProtocolType::new(db, ProtocolInterface::with_property_members(db, members)), )) } /// Return `true` if `self` conforms to the interface described by `protocol`. - /// - /// TODO: we may need to split this into two methods in the future, once we start - /// differentiating between fully-static and non-fully-static protocols. pub(super) fn satisfies_protocol( self, db: &'db dyn Db, protocol: ProtocolInstanceType<'db>, + relation: TypeRelation, ) -> bool { - // TODO: this should consider the types of the protocol members - protocol.inner.interface(db).members(db).all(|member| { - matches!( - self.member(db, member.name()).place, - Place::Type(_, Boundness::Bound) - ) - }) + protocol + .inner + .interface(db) + .members(db) + .all(|member| member.is_satisfied_by(db, self, relation)) } } @@ -205,7 +202,7 @@ impl<'db> ProtocolInstanceType<'db> { /// See [`Type::normalized`] for more details. pub(super) fn normalized(self, db: &'db dyn Db) -> Type<'db> { let object = KnownClass::Object.to_instance(db); - if object.satisfies_protocol(db, self) { + if object.satisfies_protocol(db, self, TypeRelation::Subtyping) { return object; } match self.inner { @@ -322,6 +319,10 @@ impl<'db> ProtocolInstanceType<'db> { } } } + + pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { + self.inner.interface(db) + } } /// An enumeration of the two kinds of protocol types: those that originate from a class diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 228f83563a..e79ae39e3a 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -832,9 +832,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let constraint = Type::synthesized_protocol( + // Since `hasattr` only checks if an attribute is readable, + // the type of the protocol member should be a read-only property that returns `object`. + let constraint = Type::protocol_with_readonly_members( self.db, - [(attr, KnownClass::Object.to_instance(self.db))], + [(attr, Type::object(self.db))], ); return Some(NarrowingConstraints::from_iter([( diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 2c4b167072..ff79a7b313 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -6,10 +6,12 @@ use ruff_python_ast::name::Name; use crate::{ Db, FxOrderSet, - place::{place_from_bindings, place_from_declarations}, + place::{Boundness, Place, place_from_bindings, place_from_declarations}, semantic_index::{place_table, use_def_map}, types::{ - ClassBase, ClassLiteral, KnownFunction, Type, TypeMapping, TypeQualifiers, TypeVarInstance, + CallableType, ClassBase, ClassLiteral, KnownFunction, PropertyInstanceType, Signature, + Type, TypeMapping, TypeQualifiers, TypeRelation, TypeVarInstance, + signatures::{Parameter, Parameters}, }, }; @@ -82,18 +84,30 @@ pub(super) enum ProtocolInterface<'db> { } impl<'db> ProtocolInterface<'db> { - pub(super) fn with_members<'a, M>(db: &'db dyn Db, members: M) -> Self + /// Synthesize a new protocol interface with the given members. + /// + /// All created members will be covariant, read-only property members + /// rather than method members or mutable attribute members. + pub(super) fn with_property_members<'a, M>(db: &'db dyn Db, members: M) -> Self where M: IntoIterator)>, { let members: BTreeMap<_, _> = members .into_iter() .map(|(name, ty)| { + // Synthesize a read-only property (one that has a getter but no setter) + // which returns the specified type from its getter. + let property_getter_signature = Signature::new( + Parameters::new([Parameter::positional_only(Some(Name::new_static("self")))]), + Some(ty.normalized(db)), + ); + let property_getter = CallableType::single(db, property_getter_signature); + let property = PropertyInstanceType::new(db, Some(property_getter), None); ( Name::new(name), ProtocolMemberData { - ty: ty.normalized(db), qualifiers: TypeQualifiers::default(), + kind: ProtocolMemberKind::Property(property), }, ) }) @@ -116,7 +130,7 @@ impl<'db> ProtocolInterface<'db> { Self::Members(members) => { Either::Left(members.inner(db).iter().map(|(name, data)| ProtocolMember { name, - ty: data.ty, + kind: data.kind, qualifiers: data.qualifiers, })) } @@ -132,7 +146,7 @@ impl<'db> ProtocolInterface<'db> { match self { Self::Members(members) => members.inner(db).get(name).map(|data| ProtocolMember { name, - ty: data.ty, + kind: data.kind, qualifiers: data.qualifiers, }), Self::SelfReference => None, @@ -161,7 +175,7 @@ impl<'db> ProtocolInterface<'db> { type_fn: &dyn Fn(Type<'db>) -> bool, ) -> bool { self.members(db) - .any(|member| member.ty.any_over_type(db, type_fn)) + .any(|member| member.any_over_type(db, type_fn)) } pub(super) fn normalized(self, db: &'db dyn Db) -> Self { @@ -185,15 +199,7 @@ impl<'db> ProtocolInterface<'db> { members .inner(db) .iter() - .map(|(name, data)| { - ( - name.clone(), - ProtocolMemberData { - ty: data.ty.materialize(db, variance), - qualifiers: data.qualifiers, - }, - ) - }) + .map(|(name, data)| (name.clone(), data.materialize(db, variance))) .collect::>(), )), Self::SelfReference => Self::SelfReference, @@ -241,21 +247,21 @@ impl<'db> ProtocolInterface<'db> { #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)] pub(super) struct ProtocolMemberData<'db> { - ty: Type<'db>, + kind: ProtocolMemberKind<'db>, qualifiers: TypeQualifiers, } impl<'db> ProtocolMemberData<'db> { fn normalized(&self, db: &'db dyn Db) -> Self { Self { - ty: self.ty.normalized(db), + kind: self.kind.normalized(db), qualifiers: self.qualifiers, } } fn apply_type_mapping<'a>(&self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>) -> Self { Self { - ty: self.ty.apply_type_mapping(db, type_mapping), + kind: self.kind.apply_type_mapping(db, type_mapping), qualifiers: self.qualifiers, } } @@ -265,7 +271,75 @@ impl<'db> ProtocolMemberData<'db> { db: &'db dyn Db, typevars: &mut FxOrderSet>, ) { - self.ty.find_legacy_typevars(db, typevars); + self.kind.find_legacy_typevars(db, typevars); + } + + fn materialize(&self, db: &'db dyn Db, variance: TypeVarVariance) -> Self { + Self { + kind: self.kind.materialize(db, variance), + qualifiers: self.qualifiers, + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)] +enum ProtocolMemberKind<'db> { + Method(Type<'db>), // TODO: use CallableType + Property(PropertyInstanceType<'db>), + Other(Type<'db>), +} + +impl<'db> ProtocolMemberKind<'db> { + fn normalized(&self, db: &'db dyn Db) -> Self { + match self { + ProtocolMemberKind::Method(callable) => { + ProtocolMemberKind::Method(callable.normalized(db)) + } + ProtocolMemberKind::Property(property) => { + ProtocolMemberKind::Property(property.normalized(db)) + } + ProtocolMemberKind::Other(ty) => ProtocolMemberKind::Other(ty.normalized(db)), + } + } + + fn apply_type_mapping<'a>(&self, db: &'db dyn Db, type_mapping: &TypeMapping<'a, 'db>) -> Self { + match self { + ProtocolMemberKind::Method(callable) => { + ProtocolMemberKind::Method(callable.apply_type_mapping(db, type_mapping)) + } + ProtocolMemberKind::Property(property) => { + ProtocolMemberKind::Property(property.apply_type_mapping(db, type_mapping)) + } + ProtocolMemberKind::Other(ty) => { + ProtocolMemberKind::Other(ty.apply_type_mapping(db, type_mapping)) + } + } + } + + fn find_legacy_typevars( + &self, + db: &'db dyn Db, + typevars: &mut FxOrderSet>, + ) { + match self { + ProtocolMemberKind::Method(callable) => callable.find_legacy_typevars(db, typevars), + ProtocolMemberKind::Property(property) => property.find_legacy_typevars(db, typevars), + ProtocolMemberKind::Other(ty) => ty.find_legacy_typevars(db, typevars), + } + } + + fn materialize(self, db: &'db dyn Db, variance: TypeVarVariance) -> Self { + match self { + ProtocolMemberKind::Method(callable) => { + ProtocolMemberKind::Method(callable.materialize(db, variance)) + } + ProtocolMemberKind::Property(property) => { + ProtocolMemberKind::Property(property.materialize(db, variance)) + } + ProtocolMemberKind::Other(ty) => { + ProtocolMemberKind::Other(ty.materialize(db, variance)) + } + } } } @@ -273,7 +347,7 @@ impl<'db> ProtocolMemberData<'db> { #[derive(Debug, PartialEq, Eq)] pub(super) struct ProtocolMember<'a, 'db> { name: &'a str, - ty: Type<'db>, + kind: ProtocolMemberKind<'db>, qualifiers: TypeQualifiers, } @@ -282,13 +356,52 @@ impl<'a, 'db> ProtocolMember<'a, 'db> { self.name } - pub(super) fn ty(&self) -> Type<'db> { - self.ty - } - pub(super) fn qualifiers(&self) -> TypeQualifiers { self.qualifiers } + + pub(super) fn ty(&self) -> Type<'db> { + match &self.kind { + ProtocolMemberKind::Method(callable) => *callable, + ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property), + ProtocolMemberKind::Other(ty) => *ty, + } + } + + pub(super) const fn is_attribute_member(&self) -> bool { + matches!(self.kind, ProtocolMemberKind::Other(_)) + } + + /// Return `true` if `other` contains an attribute/method/property that satisfies + /// the part of the interface defined by this protocol member. + pub(super) fn is_satisfied_by( + &self, + db: &'db dyn Db, + other: Type<'db>, + relation: TypeRelation, + ) -> bool { + let Place::Type(attribute_type, Boundness::Bound) = other.member(db, self.name).place + else { + return false; + }; + + match &self.kind { + // TODO: consider the types of the attribute on `other` for property/method members + ProtocolMemberKind::Method(_) | ProtocolMemberKind::Property(_) => true, + ProtocolMemberKind::Other(member_type) => { + member_type.has_relation_to(db, attribute_type, relation) + && attribute_type.has_relation_to(db, *member_type, relation) + } + } + } + + fn any_over_type(&self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool { + match &self.kind { + ProtocolMemberKind::Method(callable) => callable.any_over_type(db, type_fn), + ProtocolMemberKind::Property(property) => property.any_over_type(db, type_fn), + ProtocolMemberKind::Other(ty) => ty.any_over_type(db, type_fn), + } + } } /// Returns `true` if a declaration or binding to a given name in a protocol class body @@ -330,6 +443,12 @@ fn excluded_from_proto_members(member: &str) -> bool { ) || member.starts_with("_abc_") } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum BoundOnClass { + Yes, + No, +} + /// Inner Salsa query for [`ProtocolClassLiteral::interface`]. #[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)] fn cached_protocol_interface<'db>( @@ -357,7 +476,7 @@ fn cached_protocol_interface<'db>( place .place .ignore_possibly_unbound() - .map(|ty| (place_id, ty, place.qualifiers)) + .map(|ty| (place_id, ty, place.qualifiers, BoundOnClass::No)) }) // Bindings in the class body that are not declared in the class body // are not valid protocol members, and we plan to emit diagnostics for them @@ -371,20 +490,41 @@ fn cached_protocol_interface<'db>( |(place_id, bindings)| { place_from_bindings(db, bindings) .ignore_possibly_unbound() - .map(|ty| (place_id, ty, TypeQualifiers::default())) + .map(|ty| (place_id, ty, TypeQualifiers::default(), BoundOnClass::Yes)) }, )) - .filter_map(|(place_id, member, qualifiers)| { + .filter_map(|(place_id, member, qualifiers, bound_on_class)| { Some(( place_table.place_expr(place_id).as_name()?, member, qualifiers, + bound_on_class, )) }) - .filter(|(name, _, _)| !excluded_from_proto_members(name)) - .map(|(name, ty, qualifiers)| { - let ty = ty.replace_self_reference(db, class); - let member = ProtocolMemberData { ty, qualifiers }; + .filter(|(name, _, _, _)| !excluded_from_proto_members(name)) + .map(|(name, ty, qualifiers, bound_on_class)| { + let kind = match (ty, bound_on_class) { + // TODO: if the getter or setter is a function literal, we should + // upcast it to a `CallableType` so that two protocols with identical property + // members are recognized as equivalent. + (Type::PropertyInstance(property), _) => { + ProtocolMemberKind::Property(property) + } + (Type::Callable(callable), BoundOnClass::Yes) + if callable.is_function_like(db) => + { + ProtocolMemberKind::Method(ty.replace_self_reference(db, class)) + } + // TODO: method members that have `FunctionLiteral` types should be upcast + // to `CallableType` so that two protocols with identical method members + // are recognized as equivalent. + (Type::FunctionLiteral(_function), BoundOnClass::Yes) => { + ProtocolMemberKind::Method(ty.replace_self_reference(db, class)) + } + _ => ProtocolMemberKind::Other(ty.replace_self_reference(db, class)), + }; + + let member = ProtocolMemberData { kind, qualifiers }; (name.clone(), member) }), );