diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index fdd848f3dc..1d951ad3ed 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -1476,8 +1476,7 @@ class P1(Protocol): class P2(Protocol): def x(self, y: int) -> None: ... -# TODO: this should pass -static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error] +static_assert(is_equivalent_to(P1, P2)) ``` As with protocols that only have non-method members, this also holds true when they appear in @@ -1487,8 +1486,7 @@ differently ordered unions: class A: ... class B: ... -# TODO: this should pass -static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error] +static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) ``` ## Narrowing of protocols @@ -1896,6 +1894,86 @@ if isinstance(obj, (B, A)): reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A) ``` +### Protocols that use `Self` + +`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that +`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack +overflows. + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing_extensions import Protocol, Self +from ty_extensions import static_assert + +class _HashObject(Protocol): + def copy(self) -> Self: ... + +class Foo: ... + +# Attempting to build this union caused us to overflow on an early version of +# +x: Foo | _HashObject +``` + +Some other similar cases that caused issues in our early `Protocol` implementation: + +`a.py`: + +```py +from typing_extensions import Protocol, Self + +class PGconn(Protocol): + def connect(self) -> Self: ... + +class Connection: + pgconn: PGconn + +def is_crdb(conn: PGconn) -> bool: + return isinstance(conn, Connection) +``` + +and: + +`b.py`: + +```py +from typing_extensions import Protocol + +class PGconn(Protocol): + def connect[T: PGconn](self: T) -> T: ... + +class Connection: + pgconn: PGconn + +def f(x: PGconn): + isinstance(x, Connection) +``` + +### Recursive protocols used as the first argument to `cast()` + +These caused issues in an early version of our `Protocol` implementation due to the fact that we use +a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or +`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive: + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import cast, Protocol + +class Iterator[T](Protocol): + def __iter__(self) -> Iterator[T]: ... + +def f(value: Iterator): + cast(Iterator, value) # error: [redundant-cast] +``` + ## TODO Add tests for: diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index 42cfa0f217..3a5314fbac 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -300,6 +300,20 @@ static_assert(not is_equivalent_to(CallableTypeOf[f12], CallableTypeOf[f13])) static_assert(not is_equivalent_to(CallableTypeOf[f13], CallableTypeOf[f12])) ``` +### Unions containing `Callable`s + +Two unions containing different `Callable` types are equivalent even if the unions are differently +ordered: + +```py +from ty_extensions import CallableTypeOf, Unknown, is_equivalent_to, static_assert + +def f(x): ... +def g(x: Unknown): ... + +static_assert(is_equivalent_to(CallableTypeOf[f] | int | str, str | int | CallableTypeOf[g])) +``` + ### Unions containing `Callable`s containing unions Differently ordered unions inside `Callable`s inside unions can still be equivalent: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index dd7bbd1c65..e156f0a474 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1102,7 +1102,7 @@ impl<'db> Type<'db> { Type::Dynamic(_) => Some(CallableType::single(db, Signature::dynamic(self))), Type::FunctionLiteral(function_literal) => { - Some(function_literal.into_callable_type(db)) + Some(Type::Callable(function_literal.into_callable_type(db))) } Type::BoundMethod(bound_method) => Some(bound_method.into_callable_type(db)), @@ -7336,6 +7336,10 @@ impl<'db> CallableType<'db> { /// /// See [`Type::is_equivalent_to`] for more details. fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { + if self == other { + return true; + } + self.is_function_like(db) == other.is_function_like(db) && self .signatures(db) diff --git a/crates/ty_python_semantic/src/types/cyclic.rs b/crates/ty_python_semantic/src/types/cyclic.rs index 7922176f2a..2d1927d163 100644 --- a/crates/ty_python_semantic/src/types/cyclic.rs +++ b/crates/ty_python_semantic/src/types/cyclic.rs @@ -1,3 +1,5 @@ +use rustc_hash::FxHashMap; + use crate::FxIndexSet; use crate::types::Type; use std::cmp::Eq; @@ -19,14 +21,27 @@ pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>; #[derive(Debug)] pub(crate) struct CycleDetector { + /// If the type we're visiting is present in `seen`, + /// it indicates that we've hit a cycle (due to a recursive type); + /// we need to immediately short circuit the whole operation and return the fallback value. + /// That's why we pop items off the end of `seen` after we've visited them. seen: FxIndexSet, + + /// Unlike `seen`, this field is a pure performance optimisation (and an essential one). + /// If the type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit a cycle: + /// it just means that we've already visited this inner type as part of a bigger call chain we're currently in. + /// Since this cache is just a performance optimisation, it doesn't make sense to pop items off the end of the + /// cache after they've been visited (it would sort-of defeat the point of a cache if we did!) + cache: FxHashMap, + fallback: R, } -impl CycleDetector { +impl CycleDetector { pub(crate) fn new(fallback: R) -> Self { CycleDetector { seen: FxIndexSet::default(), + cache: FxHashMap::default(), fallback, } } @@ -35,7 +50,12 @@ impl CycleDetector { if !self.seen.insert(item) { return self.fallback; } + if let Some(ty) = self.cache.get(&item) { + self.seen.pop(); + return *ty; + } let ret = func(self); + self.cache.insert(item, ret); self.seen.pop(); ret } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 7b59835bd8..0a814f8e13 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -767,8 +767,8 @@ impl<'db> FunctionType<'db> { } /// Convert the `FunctionType` into a [`Type::Callable`]. - pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { - Type::Callable(CallableType::new(db, self.signature(db), false)) + pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> { + CallableType::new(db, self.signature(db), false) } /// Convert the `FunctionType` into a [`Type::BoundMethod`]. diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 33dff7b9c0..c92885adeb 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -270,7 +270,14 @@ impl<'db> ProtocolInstanceType<'db> { /// /// TODO: consider the types of the members as well as their existence pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { - self.normalized(db) == other.normalized(db) + if self == other { + return true; + } + let self_normalized = self.normalized(db); + if self_normalized == Type::ProtocolInstance(other) { + return true; + } + self_normalized == other.normalized(db) } /// Return `true` if this protocol type is disjoint from the protocol `other`. diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 3d86fabe23..a70a0c0dee 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -260,7 +260,7 @@ impl<'db> ProtocolMemberData<'db> { #[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)] enum ProtocolMemberKind<'db> { - Method(Type<'db>), // TODO: use CallableType + Method(CallableType<'db>), Property(PropertyInstanceType<'db>), Other(Type<'db>), } @@ -335,7 +335,7 @@ fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( visitor: &mut V, ) { match member.kind { - ProtocolMemberKind::Method(method) => visitor.visit_type(db, method), + ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method), ProtocolMemberKind::Property(property) => { visitor.visit_property_instance_type(db, property); } @@ -354,7 +354,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> { fn ty(&self) -> Type<'db> { match &self.kind { - ProtocolMemberKind::Method(callable) => *callable, + ProtocolMemberKind::Method(callable) => Type::Callable(*callable), ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property), ProtocolMemberKind::Other(ty) => *ty, } @@ -508,13 +508,10 @@ fn cached_protocol_interface<'db>( (Type::Callable(callable), BoundOnClass::Yes) if callable.is_function_like(db) => { - ProtocolMemberKind::Method(ty) + ProtocolMemberKind::Method(callable) } - // TODO: method members that have `FunctionLiteral` types should be upcast - // to `CallableType` so that two protocols with identical method members - // are recognized as equivalent. - (Type::FunctionLiteral(_function), BoundOnClass::Yes) => { - ProtocolMemberKind::Method(ty) + (Type::FunctionLiteral(function), BoundOnClass::Yes) => { + ProtocolMemberKind::Method(function.into_callable_type(db)) } _ => ProtocolMemberKind::Other(ty), }; diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 58e3881a21..82b66c9a3d 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -1318,8 +1318,13 @@ impl<'db> Parameter<'db> { form, } = self; - // Ensure unions and intersections are ordered in the annotated type (if there is one) - let annotated_type = annotated_type.map(|ty| ty.normalized_impl(db, visitor)); + // Ensure unions and intersections are ordered in the annotated type (if there is one). + // Ensure that a parameter without an annotation is treated equivalently to a parameter + // with a dynamic type as its annotation. (We must use `Any` here as all dynamic types + // normalize to `Any`.) + let annotated_type = annotated_type + .map(|ty| ty.normalized_impl(db, visitor)) + .unwrap_or_else(Type::any); // Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters. // Ensure that we only record whether a parameter *has* a default @@ -1351,7 +1356,7 @@ impl<'db> Parameter<'db> { }; Self { - annotated_type, + annotated_type: Some(annotated_type), kind, form: *form, }