mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 21:40:51 -05:00
[ty] Add diagnostics to validate TypeIs and TypeGuard definitions (#22300)
## Summary Closes https://github.com/astral-sh/ty/issues/2267.
This commit is contained in:
@@ -797,7 +797,7 @@ class B(A):
|
||||
pass
|
||||
|
||||
class C[T]:
|
||||
def check(x: object) -> TypeIs[T]:
|
||||
def check(self, x: object) -> TypeIs[T]:
|
||||
# this is a bad check, but we only care about it type-checking
|
||||
return False
|
||||
|
||||
@@ -835,7 +835,7 @@ class B(A):
|
||||
pass
|
||||
|
||||
class C[T]:
|
||||
def check(x: object) -> TypeGuard[T]:
|
||||
def check(self, x: object) -> TypeGuard[T]:
|
||||
# this is a bad check, but we only care about it type-checking
|
||||
return False
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ def _(
|
||||
b: TypeIs[str | int],
|
||||
c: TypeGuard[bool],
|
||||
d: TypeIs[tuple[TypeOf[bytes]]],
|
||||
e: TypeGuard, # error: [invalid-type-form]
|
||||
f: TypeIs, # error: [invalid-type-form]
|
||||
e: TypeGuard, # error: [invalid-type-form] "`typing.TypeGuard` requires exactly one argument when used in a type expression"
|
||||
f: TypeIs, # error: [invalid-type-form] "`typing.TypeIs` requires exactly one argument when used in a type expression"
|
||||
):
|
||||
reveal_type(a) # revealed: TypeGuard[str]
|
||||
reveal_type(b) # revealed: TypeIs[str | int]
|
||||
@@ -46,12 +46,23 @@ A user-defined type guard must accept at least one positional argument (in addit
|
||||
for non-static methods).
|
||||
|
||||
```pyi
|
||||
from typing import Any, TypeVar
|
||||
from typing_extensions import TypeGuard, TypeIs
|
||||
|
||||
# TODO: error: [invalid-type-guard-definition]
|
||||
T = TypeVar("T")
|
||||
|
||||
# Multiple parameters are allowed
|
||||
def is_str_list(val: list[object], allow_empty: bool) -> TypeGuard[list[str]]: ...
|
||||
def is_set_of(val: set[Any], type: type[T]) -> TypeGuard[set[T]]: ...
|
||||
def is_two_element_tuple(val: tuple[object, ...], a: str, b: str) -> TypeIs[tuple[str, str]]: ...
|
||||
|
||||
# error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
|
||||
def _() -> TypeGuard[str]: ...
|
||||
|
||||
# TODO: error: [invalid-type-guard-definition]
|
||||
# error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
|
||||
def _(*args) -> TypeGuard[str]: ...
|
||||
|
||||
# error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow"
|
||||
def _(**kwargs) -> TypeIs[str]: ...
|
||||
|
||||
class _:
|
||||
@@ -63,14 +74,14 @@ class _:
|
||||
def _(a) -> TypeIs[str]: ...
|
||||
|
||||
# errors
|
||||
def _(self) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition]
|
||||
def _(self, /, *, a) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition]
|
||||
def _(self) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
|
||||
def _(self, /, *, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
|
||||
@classmethod
|
||||
def _(cls) -> TypeIs[str]: ... # TODO: error: [invalid-type-guard-definition]
|
||||
def _(cls) -> TypeIs[str]: ... # error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow"
|
||||
@classmethod
|
||||
def _() -> TypeIs[str]: ... # TODO: error: [invalid-type-guard-definition]
|
||||
def _() -> TypeIs[str]: ... # error: [invalid-type-guard-definition] "`TypeIs` function must have a parameter to narrow"
|
||||
@staticmethod
|
||||
def _(*, a) -> TypeGuard[str]: ... # TODO: error: [invalid-type-guard-definition]
|
||||
def _(*, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
|
||||
```
|
||||
|
||||
For `TypeIs` functions, the narrowed type must be assignable to the declared type of that parameter,
|
||||
@@ -86,10 +97,10 @@ def _(a: tuple[object]) -> TypeIs[tuple[str]]: ...
|
||||
def _(a: str | Any) -> TypeIs[str]: ...
|
||||
def _(a) -> TypeIs[str]: ...
|
||||
|
||||
# TODO: error: [invalid-type-guard-definition]
|
||||
# error: [invalid-type-guard-definition] "Narrowed type `str` is not assignable to the declared parameter type `int`"
|
||||
def _(a: int) -> TypeIs[str]: ...
|
||||
|
||||
# TODO: error: [invalid-type-guard-definition]
|
||||
# error: [invalid-type-guard-definition] "Narrowed type `int` is not assignable to the declared parameter type `bool | str`"
|
||||
def _(a: bool | str) -> TypeIs[int]: ...
|
||||
```
|
||||
|
||||
@@ -107,12 +118,14 @@ class C:
|
||||
@classmethod
|
||||
def g(cls, x: object) -> TypeGuard[int]:
|
||||
return True
|
||||
# TODO: this could error at definition time
|
||||
def h(self) -> TypeGuard[str]:
|
||||
|
||||
def h(
|
||||
self,
|
||||
) -> TypeGuard[str]: # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
|
||||
return True
|
||||
# TODO: this could error at definition time
|
||||
|
||||
@classmethod
|
||||
def j(cls) -> TypeGuard[int]:
|
||||
def j(cls) -> TypeGuard[int]: # error: [invalid-type-guard-definition] "`TypeGuard` function must have a parameter to narrow"
|
||||
return True
|
||||
|
||||
def _(x: object):
|
||||
@@ -221,7 +234,7 @@ def g(a: object) -> TypeIs[int]:
|
||||
return True
|
||||
|
||||
def _(d: Any):
|
||||
if f(): # error: [missing-argument]
|
||||
if f(): # error: [missing-argument] "No argument provided for required parameter `a` of function `f`"
|
||||
...
|
||||
|
||||
if g(*d):
|
||||
@@ -230,7 +243,7 @@ def _(d: Any):
|
||||
if f("foo"): # TODO: error: [invalid-type-guard-call]
|
||||
...
|
||||
|
||||
if g(a=d): # error: [invalid-type-guard-call]
|
||||
if g(a=d): # error: [invalid-type-guard-call] "Type guard call does not have a target"
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@@ -68,8 +68,8 @@ use crate::types::diagnostic::{
|
||||
INVALID_GENERIC_ENUM, INVALID_KEY, INVALID_LEGACY_TYPE_VARIABLE, INVALID_METACLASS,
|
||||
INVALID_NAMED_TUPLE, INVALID_NEWTYPE, INVALID_OVERLOAD, INVALID_PARAMETER_DEFAULT,
|
||||
INVALID_PARAMSPEC, INVALID_PROTOCOL, INVALID_TYPE_ARGUMENTS, INVALID_TYPE_FORM,
|
||||
INVALID_TYPE_GUARD_CALL, INVALID_TYPE_VARIABLE_CONSTRAINTS, INVALID_TYPED_DICT_STATEMENT,
|
||||
IncompatibleBases, NOT_SUBSCRIPTABLE, POSSIBLY_MISSING_ATTRIBUTE,
|
||||
INVALID_TYPE_GUARD_CALL, INVALID_TYPE_GUARD_DEFINITION, INVALID_TYPE_VARIABLE_CONSTRAINTS,
|
||||
INVALID_TYPED_DICT_STATEMENT, IncompatibleBases, NOT_SUBSCRIPTABLE, POSSIBLY_MISSING_ATTRIBUTE,
|
||||
POSSIBLY_MISSING_IMPLICIT_CALL, POSSIBLY_MISSING_IMPORT, SUBCLASS_OF_FINAL_CLASS,
|
||||
TypedDictDeleteErrorKind, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_GLOBAL,
|
||||
UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE, UNSUPPORTED_DYNAMIC_BASE, UNSUPPORTED_OPERATOR,
|
||||
@@ -587,6 +587,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
if self.db().should_check_file(self.file()) {
|
||||
self.check_static_class_definitions();
|
||||
self.check_overloaded_functions(node);
|
||||
self.check_type_guard_definitions();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1452,6 +1453,85 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that all type guard function definitions have at least one positional parameter
|
||||
/// (in addition to `self`/`cls` for methods), and for `TypeIs`, that the narrowed type is
|
||||
/// assignable to the declared type of that parameter.
|
||||
fn check_type_guard_definitions(&mut self) {
|
||||
for (definition, ty) in self.declarations.iter() {
|
||||
// Only check actual function definitions, not imports.
|
||||
let DefinitionKind::Function(function_ref) = definition.kind(self.db()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(function) = ty.inner_type().as_function_literal() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for overload in function.iter_overloads_and_implementation(self.db()) {
|
||||
let signature = overload.signature(self.db());
|
||||
let return_ty = signature.return_ty;
|
||||
|
||||
// Check if this is a `TypeIs` or `TypeGuard` return type.
|
||||
let (type_guard_form_name, narrowed_type) = match return_ty {
|
||||
Type::TypeIs(type_is) => ("TypeIs", Some(type_is.return_type(self.db()))),
|
||||
Type::TypeGuard(_) => ("TypeGuard", None),
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let function_node = function_ref.node(self.module());
|
||||
|
||||
// The return type annotation must exist since we matched `TypeIs`/`TypeGuard`.
|
||||
let Some(returns_expr) = function_node.returns.as_deref() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Check if this is a non-static method (first parameter is implicit `self`/`cls`).
|
||||
let is_method = self
|
||||
.index
|
||||
.class_definition_of_method(
|
||||
overload.body_scope(self.db()).file_scope_id(self.db()),
|
||||
)
|
||||
.is_some();
|
||||
let has_implicit_receiver = is_method && !overload.is_staticmethod(self.db());
|
||||
|
||||
// Find the first positional parameter to narrow (skip implicit `self`/`cls`).
|
||||
let positional_params: Vec<_> = signature.parameters().positional().collect();
|
||||
let first_narrowed_param_index = usize::from(has_implicit_receiver);
|
||||
let first_narrowed_param = positional_params.get(first_narrowed_param_index);
|
||||
|
||||
let Some(first_narrowed_param) = first_narrowed_param else {
|
||||
if let Some(builder) = self
|
||||
.context
|
||||
.report_lint(&INVALID_TYPE_GUARD_DEFINITION, returns_expr)
|
||||
{
|
||||
builder.into_diagnostic(format_args!(
|
||||
"`{type_guard_form_name}` function must have a parameter to narrow"
|
||||
));
|
||||
}
|
||||
continue;
|
||||
};
|
||||
|
||||
// For `TypeIs`, check that the narrowed type is assignable to the parameter type.
|
||||
if let Some(narrowed_ty) = narrowed_type {
|
||||
let param_ty = first_narrowed_param.annotated_type();
|
||||
if !narrowed_ty.is_assignable_to(self.db(), param_ty) {
|
||||
if let Some(builder) = self
|
||||
.context
|
||||
.report_lint(&INVALID_TYPE_GUARD_DEFINITION, returns_expr)
|
||||
{
|
||||
builder.into_diagnostic(format_args!(
|
||||
"Narrowed type `{narrowed}` is not assignable \
|
||||
to the declared parameter type `{param}`",
|
||||
narrowed = narrowed_ty.display(self.db()),
|
||||
param = param_ty.display(self.db())
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_region_definition(&mut self, definition: Definition<'db>) {
|
||||
match definition.kind(self.db()) {
|
||||
DefinitionKind::Function(function) => {
|
||||
|
||||
Reference in New Issue
Block a user