From b5814b91c182d1be5762d69a1c6db804ebfb7755 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 13 Jan 2026 20:24:05 -0500 Subject: [PATCH] [ty] Add diagnostics to validate `TypeIs` and `TypeGuard` definitions (#22300) ## Summary Closes https://github.com/astral-sh/ty/issues/2267. --- .../mdtest/generics/pep695/variance.md | 4 +- .../resources/mdtest/narrow/type_guards.md | 47 +++++++---- .../src/types/infer/builder.rs | 84 ++++++++++++++++++- 3 files changed, 114 insertions(+), 21 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md index a04118988e..cb55bc5cd7 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md @@ -797,7 +797,7 @@ class B(A): pass class C[T]: - def check(x: object) -> TypeIs[T]: + def check(self, x: object) -> TypeIs[T]: # this is a bad check, but we only care about it type-checking return False @@ -835,7 +835,7 @@ class B(A): pass class C[T]: - def check(x: object) -> TypeGuard[T]: + def check(self, x: object) -> TypeGuard[T]: # this is a bad check, but we only care about it type-checking return False diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index 170dab1162..1036f61df2 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -14,8 +14,8 @@ def _( b: TypeIs[str | int], c: TypeGuard[bool], d: TypeIs[tuple[TypeOf[bytes]]], - e: TypeGuard, # error: [invalid-type-form] - f: TypeIs, # error: [invalid-type-form] + e: TypeGuard, # error: [invalid-type-form] "`typing.TypeGuard` requires exactly one argument when used in a type expression" + f: TypeIs, # error: [invalid-type-form] "`typing.TypeIs` requires exactly one argument when used in a type expression" ): reveal_type(a) # revealed: TypeGuard[str] reveal_type(b) # revealed: TypeIs[str | int] @@ -46,12 +46,23 @@ A user-defined type guard must accept at least one positional argument (in addit for non-static methods). ```pyi +from typing import Any, TypeVar from typing_extensions import TypeGuard, TypeIs -# TODO: error: [invalid-type-guard-definition] +T = TypeVar("T") + +# Multiple parameters are allowed +def is_str_list(val: list[object], allow_empty: bool) -> TypeGuard[list[str]]: ... +def is_set_of(val: set[Any], type: type[T]) -> TypeGuard[set[T]]: ... +def is_two_element_tuple(val: tuple[object, ...], a: str, b: str) -> TypeIs[tuple[str, str]]: ... + +# error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow" def _() -> TypeGuard[str]: ... -# TODO: error: [invalid-type-guard-definition] +# error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow" +def _(*args) -> TypeGuard[str]: ... + +# error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow" def _(**kwargs) -> TypeIs[str]: ... class _: @@ -63,14 +74,14 @@ class _: def _(a) -> TypeIs[str]: ... # errors - def _(self) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition] - def _(self, /, *, a) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition] + def _(self) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow" + def _(self, /, *, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow" @classmethod - def _(cls) -> TypeIs[str]: ... # TODO: error: [invalid-type-guard-definition] + def _(cls) -> TypeIs[str]: ... # error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow" @classmethod - def _() -> TypeIs[str]: ... # TODO: error: [invalid-type-guard-definition] + def _() -> TypeIs[str]: ... # error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow" @staticmethod - def _(*, a) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition] + def _(*, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow" ``` For `TypeIs` functions, the narrowed type must be assignable to the declared type of that parameter, @@ -86,10 +97,10 @@ def _(a: tuple[object]) -> TypeIs[tuple[str]]: ... def _(a: str | Any) -> TypeIs[str]: ... def _(a) -> TypeIs[str]: ... -# TODO: error: [invalid-type-guard-definition] +# error: [invalid-type-guard-definition] "Narrowed type `str` is not assignable to the declared parameter type `int`" def _(a: int) -> TypeIs[str]: ... -# TODO: error: [invalid-type-guard-definition] +# error: [invalid-type-guard-definition] "Narrowed type `int` is not assignable to the declared parameter type `bool | str`" def _(a: bool | str) -> TypeIs[int]: ... ``` @@ -107,12 +118,14 @@ class C: @classmethod def g(cls, x: object) -> TypeGuard[int]: return True - # TODO: this could error at definition time - def h(self) -> TypeGuard[str]: + + def h( + self, + ) -> TypeGuard[str]: # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow" return True - # TODO: this could error at definition time + @classmethod - def j(cls) -> TypeGuard[int]: + def j(cls) -> TypeGuard[int]: # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow" return True def _(x: object): @@ -221,7 +234,7 @@ def g(a: object) -> TypeIs[int]: return True def _(d: Any): - if f(): # error: [missing-argument] + if f(): # error: [missing-argument] "No argument provided for required parameter `a` of function `f`" ... if g(*d): @@ -230,7 +243,7 @@ def _(d: Any): if f("foo"): # TODO: error: [invalid-type-guard-call] ... - if g(a=d): # error: [invalid-type-guard-call] + if g(a=d): # error: [invalid-type-guard-call] "Type guard call does not have a target" ... ``` diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 214584521c..242cd542cf 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -68,8 +68,8 @@ use crate::types::diagnostic::{ INVALID_GENERIC_ENUM, INVALID_KEY, INVALID_LEGACY_TYPE_VARIABLE, INVALID_METACLASS, INVALID_NAMED_TUPLE, INVALID_NEWTYPE, INVALID_OVERLOAD, INVALID_PARAMETER_DEFAULT, INVALID_PARAMSPEC, INVALID_PROTOCOL, INVALID_TYPE_ARGUMENTS, INVALID_TYPE_FORM, - INVALID_TYPE_GUARD_CALL, INVALID_TYPE_VARIABLE_CONSTRAINTS, INVALID_TYPED_DICT_STATEMENT, - IncompatibleBases, NOT_SUBSCRIPTABLE, POSSIBLY_MISSING_ATTRIBUTE, + INVALID_TYPE_GUARD_CALL, INVALID_TYPE_GUARD_DEFINITION, INVALID_TYPE_VARIABLE_CONSTRAINTS, + INVALID_TYPED_DICT_STATEMENT, IncompatibleBases, NOT_SUBSCRIPTABLE, POSSIBLY_MISSING_ATTRIBUTE, POSSIBLY_MISSING_IMPLICIT_CALL, POSSIBLY_MISSING_IMPORT, SUBCLASS_OF_FINAL_CLASS, TypedDictDeleteErrorKind, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_GLOBAL, UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE, UNSUPPORTED_DYNAMIC_BASE, UNSUPPORTED_OPERATOR, @@ -587,6 +587,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if self.db().should_check_file(self.file()) { self.check_static_class_definitions(); self.check_overloaded_functions(node); + self.check_type_guard_definitions(); } } @@ -1452,6 +1453,85 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } + /// Check that all type guard function definitions have at least one positional parameter + /// (in addition to `self`/`cls` for methods), and for `TypeIs`, that the narrowed type is + /// assignable to the declared type of that parameter. + fn check_type_guard_definitions(&mut self) { + for (definition, ty) in self.declarations.iter() { + // Only check actual function definitions, not imports. + let DefinitionKind::Function(function_ref) = definition.kind(self.db()) else { + continue; + }; + + let Some(function) = ty.inner_type().as_function_literal() else { + continue; + }; + + for overload in function.iter_overloads_and_implementation(self.db()) { + let signature = overload.signature(self.db()); + let return_ty = signature.return_ty; + + // Check if this is a `TypeIs` or `TypeGuard` return type. + let (type_guard_form_name, narrowed_type) = match return_ty { + Type::TypeIs(type_is) => ("TypeIs", Some(type_is.return_type(self.db()))), + Type::TypeGuard(_) => ("TypeGuard", None), + _ => continue, + }; + + let function_node = function_ref.node(self.module()); + + // The return type annotation must exist since we matched `TypeIs`/`TypeGuard`. + let Some(returns_expr) = function_node.returns.as_deref() else { + continue; + }; + + // Check if this is a non-static method (first parameter is implicit `self`/`cls`). + let is_method = self + .index + .class_definition_of_method( + overload.body_scope(self.db()).file_scope_id(self.db()), + ) + .is_some(); + let has_implicit_receiver = is_method && !overload.is_staticmethod(self.db()); + + // Find the first positional parameter to narrow (skip implicit `self`/`cls`). + let positional_params: Vec<_> = signature.parameters().positional().collect(); + let first_narrowed_param_index = usize::from(has_implicit_receiver); + let first_narrowed_param = positional_params.get(first_narrowed_param_index); + + let Some(first_narrowed_param) = first_narrowed_param else { + if let Some(builder) = self + .context + .report_lint(&INVALID_TYPE_GUARD_DEFINITION, returns_expr) + { + builder.into_diagnostic(format_args!( + "`{type_guard_form_name}` function must have a parameter to narrow" + )); + } + continue; + }; + + // For `TypeIs`, check that the narrowed type is assignable to the parameter type. + if let Some(narrowed_ty) = narrowed_type { + let param_ty = first_narrowed_param.annotated_type(); + if !narrowed_ty.is_assignable_to(self.db(), param_ty) { + if let Some(builder) = self + .context + .report_lint(&INVALID_TYPE_GUARD_DEFINITION, returns_expr) + { + builder.into_diagnostic(format_args!( + "Narrowed type `{narrowed}` is not assignable \ + to the declared parameter type `{param}`", + narrowed = narrowed_ty.display(self.db()), + param = param_ty.display(self.db()) + )); + } + } + } + } + } + } + fn infer_region_definition(&mut self, definition: Definition<'db>) { match definition.kind(self.db()) { DefinitionKind::Function(function) => {