[ty] Add diagnostics to validate TypeIs and TypeGuard definitions (#22300)

## Summary

Closes https://github.com/astral-sh/ty/issues/2267.
This commit is contained in:
Charlie Marsh
2026-01-13 20:24:05 -05:00
committed by GitHub
parent ea46426157
commit b5814b91c1
3 changed files with 114 additions and 21 deletions

View File

@@ -797,7 +797,7 @@ class B(A):
pass
class C[T]:
def check(x: object) -> TypeIs[T]:
def check(self, x: object) -> TypeIs[T]:
# this is a bad check, but we only care about it type-checking
return False
@@ -835,7 +835,7 @@ class B(A):
pass
class C[T]:
def check(x: object) -> TypeGuard[T]:
def check(self, x: object) -> TypeGuard[T]:
# this is a bad check, but we only care about it type-checking
return False

View File

@@ -14,8 +14,8 @@ def _(
b: TypeIs[str | int],
c: TypeGuard[bool],
d: TypeIs[tuple[TypeOf[bytes]]],
e: TypeGuard, # error: [invalid-type-form]
f: TypeIs, # error: [invalid-type-form]
e: TypeGuard, # error: [invalid-type-form] "`typing.TypeGuard` requires exactly one argument when used in a type expression"
f: TypeIs, # error: [invalid-type-form] "`typing.TypeIs` requires exactly one argument when used in a type expression"
):
reveal_type(a) # revealed: TypeGuard[str]
reveal_type(b) # revealed: TypeIs[str | int]
@@ -46,12 +46,23 @@ A user-defined type guard must accept at least one positional argument (in addit
for non-static methods).
```pyi
from typing import Any, TypeVar
from typing_extensions import TypeGuard, TypeIs
# TODO: error: [invalid-type-guard-definition]
T = TypeVar("T")
# Multiple parameters are allowed
def is_str_list(val: list[object], allow_empty: bool) -> TypeGuard[list[str]]: ...
def is_set_of(val: set[Any], type: type[T]) -> TypeGuard[set[T]]: ...
def is_two_element_tuple(val: tuple[object, ...], a: str, b: str) -> TypeIs[tuple[str, str]]: ...
# error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
def _() -> TypeGuard[str]: ...
# TODO: error: [invalid-type-guard-definition]
# error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
def _(*args) -> TypeGuard[str]: ...
# error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow"
def _(**kwargs) -> TypeIs[str]: ...
class _:
@@ -63,14 +74,14 @@ class _:
def _(a) -> TypeIs[str]: ...
# errors
def _(self) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition]
def _(self, /, *, a) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition]
def _(self) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
def _(self, /, *, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
@classmethod
def _(cls) -> TypeIs[str]: ... # TODO: error: [invalid-type-guard-definition]
def _(cls) -> TypeIs[str]: ... # error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow"
@classmethod
def _() -> TypeIs[str]: ... # TODO: error: [invalid-type-guard-definition]
def _() -> TypeIs[str]: ... # error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow"
@staticmethod
def _(*, a) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition]
def _(*, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
```
For `TypeIs` functions, the narrowed type must be assignable to the declared type of that parameter,
@@ -86,10 +97,10 @@ def _(a: tuple[object]) -> TypeIs[tuple[str]]: ...
def _(a: str | Any) -> TypeIs[str]: ...
def _(a) -> TypeIs[str]: ...
# TODO: error: [invalid-type-guard-definition]
# error: [invalid-type-guard-definition] "Narrowed type `str` is not assignable to the declared parameter type `int`"
def _(a: int) -> TypeIs[str]: ...
# TODO: error: [invalid-type-guard-definition]
# error: [invalid-type-guard-definition] "Narrowed type `int` is not assignable to the declared parameter type `bool | str`"
def _(a: bool | str) -> TypeIs[int]: ...
```
@@ -107,12 +118,14 @@ class C:
@classmethod
def g(cls, x: object) -> TypeGuard[int]:
return True
# TODO: this could error at definition time
def h(self) -> TypeGuard[str]:
def h(
self,
) -> TypeGuard[str]: # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
return True
# TODO: this could error at definition time
@classmethod
def j(cls) -> TypeGuard[int]:
def j(cls) -> TypeGuard[int]: # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
return True
def _(x: object):
@@ -221,7 +234,7 @@ def g(a: object) -> TypeIs[int]:
return True
def _(d: Any):
if f(): # error: [missing-argument]
if f(): # error: [missing-argument] "No argument provided for required parameter `a` of function `f`"
...
if g(*d):
@@ -230,7 +243,7 @@ def _(d: Any):
if f("foo"): # TODO: error: [invalid-type-guard-call]
...
if g(a=d): # error: [invalid-type-guard-call]
if g(a=d): # error: [invalid-type-guard-call] "Type guard call does not have a target"
...
```

View File

@@ -68,8 +68,8 @@ use crate::types::diagnostic::{
INVALID_GENERIC_ENUM, INVALID_KEY, INVALID_LEGACY_TYPE_VARIABLE, INVALID_METACLASS,
INVALID_NAMED_TUPLE, INVALID_NEWTYPE, INVALID_OVERLOAD, INVALID_PARAMETER_DEFAULT,
INVALID_PARAMSPEC, INVALID_PROTOCOL, INVALID_TYPE_ARGUMENTS, INVALID_TYPE_FORM,
INVALID_TYPE_GUARD_CALL, INVALID_TYPE_VARIABLE_CONSTRAINTS, INVALID_TYPED_DICT_STATEMENT,
IncompatibleBases, NOT_SUBSCRIPTABLE, POSSIBLY_MISSING_ATTRIBUTE,
INVALID_TYPE_GUARD_CALL, INVALID_TYPE_GUARD_DEFINITION, INVALID_TYPE_VARIABLE_CONSTRAINTS,
INVALID_TYPED_DICT_STATEMENT, IncompatibleBases, NOT_SUBSCRIPTABLE, POSSIBLY_MISSING_ATTRIBUTE,
POSSIBLY_MISSING_IMPLICIT_CALL, POSSIBLY_MISSING_IMPORT, SUBCLASS_OF_FINAL_CLASS,
TypedDictDeleteErrorKind, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_GLOBAL,
UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE, UNSUPPORTED_DYNAMIC_BASE, UNSUPPORTED_OPERATOR,
@@ -587,6 +587,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if self.db().should_check_file(self.file()) {
self.check_static_class_definitions();
self.check_overloaded_functions(node);
self.check_type_guard_definitions();
}
}
@@ -1452,6 +1453,85 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
/// Check that all type guard function definitions have at least one positional parameter
/// (in addition to `self`/`cls` for methods), and for `TypeIs`, that the narrowed type is
/// assignable to the declared type of that parameter.
fn check_type_guard_definitions(&mut self) {
for (definition, ty) in self.declarations.iter() {
// Only check actual function definitions, not imports.
let DefinitionKind::Function(function_ref) = definition.kind(self.db()) else {
continue;
};
let Some(function) = ty.inner_type().as_function_literal() else {
continue;
};
for overload in function.iter_overloads_and_implementation(self.db()) {
let signature = overload.signature(self.db());
let return_ty = signature.return_ty;
// Check if this is a `TypeIs` or `TypeGuard` return type.
let (type_guard_form_name, narrowed_type) = match return_ty {
Type::TypeIs(type_is) => ("TypeIs", Some(type_is.return_type(self.db()))),
Type::TypeGuard(_) => ("TypeGuard", None),
_ => continue,
};
let function_node = function_ref.node(self.module());
// The return type annotation must exist since we matched `TypeIs`/`TypeGuard`.
let Some(returns_expr) = function_node.returns.as_deref() else {
continue;
};
// Check if this is a non-static method (first parameter is implicit `self`/`cls`).
let is_method = self
.index
.class_definition_of_method(
overload.body_scope(self.db()).file_scope_id(self.db()),
)
.is_some();
let has_implicit_receiver = is_method && !overload.is_staticmethod(self.db());
// Find the first positional parameter to narrow (skip implicit `self`/`cls`).
let positional_params: Vec<_> = signature.parameters().positional().collect();
let first_narrowed_param_index = usize::from(has_implicit_receiver);
let first_narrowed_param = positional_params.get(first_narrowed_param_index);
let Some(first_narrowed_param) = first_narrowed_param else {
if let Some(builder) = self
.context
.report_lint(&INVALID_TYPE_GUARD_DEFINITION, returns_expr)
{
builder.into_diagnostic(format_args!(
"`{type_guard_form_name}` function must have a parameter to narrow"
));
}
continue;
};
// For `TypeIs`, check that the narrowed type is assignable to the parameter type.
if let Some(narrowed_ty) = narrowed_type {
let param_ty = first_narrowed_param.annotated_type();
if !narrowed_ty.is_assignable_to(self.db(), param_ty) {
if let Some(builder) = self
.context
.report_lint(&INVALID_TYPE_GUARD_DEFINITION, returns_expr)
{
builder.into_diagnostic(format_args!(
"Narrowed type `{narrowed}` is not assignable \
to the declared parameter type `{param}`",
narrowed = narrowed_ty.display(self.db()),
param = param_ty.display(self.db())
));
}
}
}
}
}
}
fn infer_region_definition(&mut self, definition: Definition<'db>) {
match definition.kind(self.db()) {
DefinitionKind::Function(function) => {