This commit is contained in:
Shunsuke Shibayama 2025-12-16 14:08:03 +08:00 committed by GitHub
commit 4a02968fdf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 657 additions and 27 deletions

View File

@ -194,7 +194,8 @@ static SYMPY: Benchmark = Benchmark::new(
max_dep_date: "2025-06-17",
python_version: PythonVersion::PY312,
},
13030,
// TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably.
70000,
);
static TANJUN: Benchmark = Benchmark::new(

View File

@ -0,0 +1,17 @@
# Regression test for https://github.com/astral-sh/ruff/issues/17371
# panicked in commit d1088545a08aeb57b67ec1e3a7f5141159efefa5
# error message:
# dependency graph cycle when querying ClassType < 'db >::into_callable_(Id(1c00))
try:
class foo[T: bar](object):
pass
bar = foo
except Exception:
bar = lambda: 0
def bar():
pass
@bar()
class bar:
pass

View File

@ -0,0 +1,72 @@
def f(cond: bool):
if cond:
result = ()
result += (f(cond),)
return result
return None
reveal_type(f(True))
def f(cond: bool):
if cond:
result = ()
result += (f(cond),)
return result
return None
def f(cond: bool):
result = None
if cond:
result = ()
result += (f(cond),)
return result
reveal_type(f(True))
def f(cond: bool):
result = None
if cond:
result = [f(cond) for _ in range(1)]
return result
reveal_type(f(True))
class Foo:
def value(self):
return 1
def unwrap(value):
if isinstance(value, Foo):
foo = value
return foo.value()
elif type(value) is tuple:
length = len(value)
if length == 0:
return ()
elif length == 1:
return (unwrap(value[0]),)
else:
result = []
for item in value:
result.append(unwrap(item))
return tuple(result)
else:
raise TypeError()
def descent(x: int, y: int):
if x > y:
y, x = descent(y, x)
return x, y
if x == 1:
return (1, 0)
if y == 1:
return (0, 1)
else:
return descent(x-1, y-1)
def count_set_bits(n):
return 1 + count_set_bits(n & n - 1) if n else 0

View File

@ -79,7 +79,9 @@ def outer_sync(): # `yield` from is only valid syntax inside a synchronous func
a: (yield from [1]), # error: [invalid-type-form] "`yield from` expressions are not allowed in type expressions"
): ...
async def baz(): ...
async def baz():
yield
async def outer_async(): # avoid unrelated syntax errors on `yield` and `await`
def _(
a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression"

View File

@ -111,7 +111,7 @@ def _(flag: bool):
# error: [call-non-callable] "Object of type `Literal["This is a string literal"]` is not callable"
x = f(3)
reveal_type(x) # revealed: Unknown
reveal_type(x) # revealed: None | Unknown
```
## Union of binding errors
@ -128,7 +128,7 @@ def _(flag: bool):
# error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1"
# error: [too-many-positional-arguments] "Too many positional arguments to function `f2`: expected 0, got 1"
x = f(3)
reveal_type(x) # revealed: Unknown
reveal_type(x) # revealed: None
```
## One not-callable, one wrong argument
@ -146,7 +146,7 @@ def _(flag: bool):
# error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1"
# error: [call-non-callable] "Object of type `C` is not callable"
x = f(3)
reveal_type(x) # revealed: Unknown
reveal_type(x) # revealed: None | Unknown
```
## Union including a special-cased function

View File

@ -125,7 +125,8 @@ match obj:
```py
class C:
def __await__(self): ...
def __await__(self):
yield
# error: [invalid-syntax] "`return` statement outside of a function"
return

View File

@ -295,6 +295,331 @@ def f(cond: bool) -> int:
return 2
```
## Inferred return type
### Free function
If a function's return type is not annotated, it is inferred. The inferred type is the union of all
possible return types.
```py
def f():
return 1
reveal_type(f()) # revealed: Literal[1]
# TODO: should be `def f() -> Literal[1]`
reveal_type(f) # revealed: def f() -> Unknown
def g(cond: bool):
if cond:
return 1
else:
return "a"
reveal_type(g(True)) # revealed: Literal[1, "a"]
# This function implicitly returns `None`.
def h(x: int, y: str):
if x > 10:
return x
elif x > 5:
return y
reveal_type(h(1, "a")) # revealed: int | str | None
lambda_func = lambda: 1
# TODO: lambda function type inference
# Should be `Literal[1]`
reveal_type(lambda_func()) # revealed: Unknown
def generator():
yield 1
yield 2
return None
# TODO: Should be `Generator[Literal[1, 2], Any, None]`
reveal_type(generator()) # revealed: Unknown
async def async_generator():
yield
# TODO: Should be `AsyncGenerator[None, Any]`
reveal_type(async_generator()) # revealed: Unknown
async def coroutine():
return
# TODO: Should be `CoroutineType[Any, Any, None]`
reveal_type(coroutine()) # revealed: Unknown
```
The return type of a recursive function is also inferred. When the return type inference would
diverge, it is truncated and replaced with the special dynamic type `Divergent`.
```toml
[environment]
python-version = "3.12"
```
```py
def fibonacci(n: int):
if n == 0:
return 0
elif n == 1:
return 1
else:
return fibonacci(n - 1) + fibonacci(n - 2)
reveal_type(fibonacci(5)) # revealed: int
def even(n: int):
if n == 0:
return True
else:
return odd(n - 1)
def odd(n: int):
if n == 0:
return False
else:
return even(n - 1)
reveal_type(even(1)) # revealed: bool
reveal_type(odd(1)) # revealed: bool
def repeat_a(n: int):
if n <= 0:
return ""
else:
return repeat_a(n - 1) + "a"
reveal_type(repeat_a(3)) # revealed: str
def divergent(value):
if type(value) is tuple:
return (divergent(value[0]),)
else:
return None
# tuple[tuple[tuple[...] | None] | None] | None => tuple[Divergent] | None
reveal_type(divergent((1,))) # revealed: tuple[Divergent] | None
def call_divergent(x: int):
return (divergent((1, 2, 3)), x)
reveal_type(call_divergent(1)) # revealed: tuple[tuple[Divergent] | None, int]
def list1[T](x: T) -> list[T]:
return [x]
def divergent2(value):
if type(value) is tuple:
return (divergent2(value[0]),)
elif type(value) is list:
return list1(divergent2(value[0]))
else:
return None
reveal_type(divergent2((1,))) # revealed: tuple[Divergent] | list[Divergent] | None
def list_int(x: int):
if x > 0:
return list1(list_int(x - 1))
else:
return list1(x)
# TODO: should be `list[int]`
reveal_type(list_int(1)) # revealed: list[Divergent] | list[int]
def tuple_obj(cond: bool):
if cond:
x = object()
else:
x = tuple_obj(cond)
return (x,)
reveal_type(tuple_obj(True)) # revealed: tuple[object]
def get_non_empty(node):
for child in node.children:
node = get_non_empty(child)
if node is not None:
return node
return None
reveal_type(get_non_empty(None)) # revealed: (Divergent & ~None) | None
def nested_scope():
def inner():
return nested_scope()
return inner()
reveal_type(nested_scope()) # revealed: Never
def eager_nested_scope():
class A:
x = eager_nested_scope()
return A.x
reveal_type(eager_nested_scope()) # revealed: Unknown
class C:
def flip(self) -> "D":
return D()
class D(C):
# TODO invalid override error
def flip(self) -> "C":
return C()
def c_or_d(n: int):
if n == 0:
return D()
else:
return c_or_d(n - 1).flip()
# In fixed-point iteration of the return type inference, the return type is monotonically widened.
# For example, once the return type of `c_or_d` is determined to be `C`,
# it will never be determined to be a subtype `D` in the subsequent iterations.
reveal_type(c_or_d(1)) # revealed: C
```
### Class method
If a method's return type is not annotated, it is also inferred, but the inferred type is a union of
all possible return types and `Unknown`. This is because a method of a class may be overridden by
its subtypes. For example, if the return type of a method is inferred to be `int`, the type the
coder really intended might be `int | None`, in which case it would be impossible for the overridden
method to return `None`.
```py
class C:
def f(self):
return 1
class D(C):
def f(self):
return None
reveal_type(C().f()) # revealed: Literal[1] | Unknown
reveal_type(D().f()) # revealed: None | Literal[1] | Unknown
```
However, in the following cases, `Unknown` is not included in the inferred return type because there
is no ambiguity in the subclass.
- The class or the method is marked as `final`.
```py
from typing import final
@final
class C:
def f(self):
return 1
class D:
@final
def f(self):
return "a"
reveal_type(C().f()) # revealed: Literal[1]
reveal_type(D().f()) # revealed: Literal["a"]
```
- The method overrides the methods of the base classes, and the return types of the base class
methods are known (In this case, the return type of the method is the intersection of the return
types of the methods in the base classes).
```toml
[environment]
python-version = "3.12"
```
```py
from typing import Literal
class C:
def f(self) -> int:
return 1
def g[T](self, x: T) -> T:
return x
def h[T: int](self, x: T) -> T:
return x
def i[T: int](self, x: T) -> list[T]:
return [x]
class D(C):
def f(self):
return 2
# TODO: This should be an invalid-override error.
def g(self, x: int):
return 2
# A strict application of the Liskov Substitution Principle would consider
# this an invalid override because it violates the guarantee that the method returns
# the same type as its input type (any type smaller than int),
# but neither mypy nor pyright will throw an error for this.
def h(self, x: int):
return 2
def i(self, x: int):
return [2]
class E(D):
def f(self):
return 3
reveal_type(C().f()) # revealed: int
reveal_type(D().f()) # revealed: int
reveal_type(E().f()) # revealed: int
reveal_type(C().g(1)) # revealed: Literal[1]
reveal_type(D().g(1)) # revealed: Literal[2] | Unknown
reveal_type(C().h(1)) # revealed: Literal[1]
reveal_type(D().h(1)) # revealed: Literal[2] | Unknown
reveal_type(C().h(True)) # revealed: Literal[True]
reveal_type(D().h(True)) # revealed: Literal[2] | Unknown
reveal_type(C().i(1)) # revealed: list[Literal[1]]
# TODO: better type for list elements
reveal_type(D().i(1)) # revealed: list[Unknown | int] | list[Unknown]
class F:
def f(self) -> Literal[1, 2]:
return 2
class G:
def f(self) -> Literal[2, 3]:
return 2
class H(F, G):
# TODO: should be an invalid-override error
def f(self):
raise NotImplementedError
class I(F, G):
# TODO: should be an invalid-override error
@final
def f(self):
raise NotImplementedError
# We use a return type of `F.f` according to the MRO.
reveal_type(H().f()) # revealed: Literal[1, 2]
reveal_type(I().f()) # revealed: Never
class C2[T]:
def f(self, x: T) -> T:
return x
class D2(C2[int]):
def f(self, x: int):
return x
reveal_type(D2().f(1)) # revealed: int
```
## Invalid return type
<!-- snapshot-diagnostics -->

View File

@ -142,20 +142,25 @@ def _(x: A | B):
reveal_type(x) # revealed: A | B
```
## No narrowing for custom `type` callable
## No special narrowing for custom `type` callable
```py
def type(x: object):
return int
class A: ...
class B: ...
def type(x):
return int
def _(x: A | B):
# The custom `type` function always returns `int`,
# so any branch other than `type(...) is int` is unreachable.
if type(x) is A:
reveal_type(x) # revealed: Never
# And the condition here is always `True` and has no effect on the narrowing of `x`.
elif type(x) is int:
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: A | B
reveal_type(x) # revealed: Never
```
## No narrowing for multiple arguments

View File

@ -1,6 +1,9 @@
use std::ops::Range;
use ruff_db::{files::File, parsed::ParsedModuleRef};
use ruff_db::{
files::File,
parsed::{ParsedModuleRef, parsed_module},
};
use ruff_index::newtype_index;
use ruff_python_ast as ast;
@ -27,6 +30,10 @@ pub struct ScopeId<'db> {
impl get_size2::GetSize for ScopeId<'_> {}
impl<'db> ScopeId<'db> {
pub(crate) fn is_non_lambda_function(self, db: &'db dyn Db) -> bool {
self.node(db).scope_kind().is_non_lambda_function()
}
pub(crate) fn is_annotation(self, db: &'db dyn Db) -> bool {
self.node(db).scope_kind().is_annotation()
}
@ -64,6 +71,18 @@ impl<'db> ScopeId<'db> {
NodeWithScopeKind::GeneratorExpression(_) => "<generator>",
}
}
pub(crate) fn is_coroutine_function(self, db: &'db dyn Db) -> bool {
let module = parsed_module(db, self.file(db)).load(db);
self.node(db)
.as_function()
.is_some_and(|func| func.node(&module).is_async && !self.is_generator_function(db))
}
pub(crate) fn is_generator_function(self, db: &'db dyn Db) -> bool {
let index = semantic_index(db, self.file(db));
self.file_scope_id(db).is_generator_function(index)
}
}
/// ID that uniquely identifies a scope inside of a module.

View File

@ -6454,6 +6454,19 @@ impl<'db> Type<'db> {
}
}
/// Returns the inferred return type of `self` if it is a function literal / bound method.
fn infer_return_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => {
Some(function_type.infer_return_type(db))
}
Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db) => {
Some(method_type.infer_return_type(db))
}
_ => None,
}
}
/// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if
/// the arguments are not compatible with the formal parameters.
///
@ -12156,6 +12169,77 @@ impl<'db> BoundMethodType<'db> {
)
}
/// Infers this method scope's types and returns the inferred return type.
#[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)]
pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> {
let scope = self
.function(db)
.literal(db)
.last_definition(db)
.body_scope(db);
let inference = infer_scope_types(db, scope);
inference.infer_return_type(db, scope, Type::BoundMethod(self))
}
#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)]
fn class_definition(self, db: &'db dyn Db) -> Option<Definition<'db>> {
let definition_scope = self.function(db).definition(db).scope(db);
let index = semantic_index(db, definition_scope.file(db));
Some(index.expect_single_definition(definition_scope.node(db).as_class()?))
}
pub(crate) fn is_final(self, db: &'db dyn Db) -> bool {
if self
.function(db)
.has_known_decorator(db, FunctionDecorators::FINAL)
{
return true;
}
let Some(class_ty) = self
.class_definition(db)
.and_then(|class| binding_type(db, class).as_class_literal())
else {
return false;
};
class_ty
.known_function_decorators(db)
.any(|deco| deco == KnownFunction::Final)
}
pub(super) fn base_return_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?;
let name = self.function(db).name(db);
let base = class
.iter_mro(db)
.nth(1)
.and_then(class_base::ClassBase::into_class)?;
let base_member = base.class_member(db, name, MemberLookupPolicy::default());
if let Place::Defined(Type::FunctionLiteral(base_func), _, _) = base_member.place {
if let [signature] = base_func.signature(db).overloads.as_slice() {
let unspecialized_return_ty = signature.return_ty.unwrap_or_else(|| {
let base_method_ty =
base_func.into_bound_method_type(db, Type::instance(db, class));
base_method_ty.infer_return_type(db)
});
if let Some(generic_context) = signature.generic_context.as_ref() {
// If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables.
Some(
unspecialized_return_ty
.apply_specialization(db, generic_context.unknown_specialization(db)),
)
} else {
Some(unspecialized_return_ty)
}
} else {
// TODO: Handle overloaded base methods.
None
}
} else {
None
}
}
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
Self::new(
db,
@ -12233,6 +12317,24 @@ impl<'db> BoundMethodType<'db> {
}
}
fn return_type_cycle_recover<'db>(
db: &'db dyn Db,
cycle: &salsa::Cycle,
previous_return_type: &Type<'db>,
return_type: Type<'db>,
_self: BoundMethodType<'db>,
) -> Type<'db> {
return_type.cycle_normalized(db, *previous_return_type, cycle)
}
fn return_type_cycle_initial<'db>(
_db: &'db dyn Db,
id: salsa::Id,
_method: BoundMethodType<'db>,
) -> Type<'db> {
Type::divergent(id)
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, get_size2::GetSize)]
pub enum CallableTypeKind {
/// Represents regular callable objects.

View File

@ -3650,10 +3650,15 @@ impl<'db> Binding<'db> {
}
}
}
for (keywords_index, keywords_type) in keywords_arguments {
matcher.match_keyword_variadic(db, keywords_index, keywords_type);
}
self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown());
self.return_ty = self.signature.return_ty.unwrap_or_else(|| {
self.callable_type
.infer_return_type(db)
.unwrap_or(Type::unknown())
});
self.parameter_tys = vec![None; parameters.len()].into_boxed_slice();
self.variadic_argument_matched_to_variadic_parameter =
matcher.variadic_argument_matched_to_variadic_parameter;

View File

@ -85,7 +85,7 @@ use crate::types::{
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType,
NormalizedVisitor, SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type,
TypeContext, TypeMapping, TypeRelation, UnionBuilder, binding_type, definition_expression_type,
infer_definition_types, walk_signature,
infer_definition_types, infer_scope_types, walk_signature,
};
use crate::{Db, FxOrderSet, ModuleName, resolve_module};
@ -1197,6 +1197,32 @@ impl<'db> FunctionType<'db> {
updated_last_definition_signature,
))
}
/// Infers this function scope's types and returns the inferred return type.
#[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)]
pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> {
let scope = self.literal(db).last_definition(db).body_scope(db);
let inference = infer_scope_types(db, scope);
inference.infer_return_type(db, scope, Type::FunctionLiteral(self))
}
}
fn return_type_cycle_recover<'db>(
db: &'db dyn Db,
cycle: &salsa::Cycle,
previous_return_type: &Type<'db>,
return_type: Type<'db>,
_self: FunctionType<'db>,
) -> Type<'db> {
return_type.cycle_normalized(db, *previous_return_type, cycle)
}
fn return_type_cycle_initial<'db>(
_db: &'db dyn Db,
id: salsa::Id,
_function: FunctionType<'db>,
) -> Type<'db> {
Type::divergent(id)
}
/// Evaluate an `isinstance` call. Return `Truthiness::AlwaysTrue` if we can definitely infer that

View File

@ -47,13 +47,13 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map};
use crate::types::diagnostic::TypeCheckDiagnostics;
use crate::types::function::FunctionType;
use crate::types::generics::Specialization;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, declaration_type,
ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, UnionBuilder, declaration_type,
};
use crate::unpack::Unpack;
use builder::TypeInferenceBuilder;
@ -546,6 +546,9 @@ struct ScopeInferenceExtra<'db> {
/// The diagnostics for this region.
diagnostics: TypeCheckDiagnostics,
/// The returned types, if it is a function body.
return_types: Vec<Type<'db>>,
}
impl<'db> ScopeInference<'db> {
@ -605,6 +608,51 @@ impl<'db> ScopeInference<'db> {
extra.string_annotations.contains(&expression.into())
}
/// Returns the inferred return type of this function body (union of all possible return types),
/// or `None` if the region is not a function body.
/// In the case of methods, the return type of the superclass method is further unioned.
/// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`.
pub(crate) fn infer_return_type(
&self,
db: &'db dyn Db,
scope: ScopeId<'db>,
callee_ty: Type<'db>,
) -> Type<'db> {
// TODO: coroutine function type inference
// TODO: generator function type inference
if scope.is_coroutine_function(db) || scope.is_generator_function(db) {
return Type::unknown();
}
let mut union = UnionBuilder::new(db);
if let Some(cycle_recovery) = self.fallback_type() {
union = union.add(cycle_recovery);
}
let Some(extra) = &self.extra else {
unreachable!(
"infer_return_type should only be called on a function body scope inference"
);
};
for return_ty in &extra.return_types {
union = union.add(*return_ty);
}
let use_def = use_def_map(db, scope);
if use_def.can_implicitly_return_none(db) {
union = union.add(Type::none(db));
}
if let Type::BoundMethod(method_ty) = callee_ty {
// If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`.
// If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`.
// On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`.
if !method_ty.is_final(db) {
union = union.add(method_ty.base_return_type(db).unwrap_or(Type::unknown()));
}
}
union.build()
}
}
/// The inferred types for a definition region.

View File

@ -12688,6 +12688,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
pub(super) fn finish_scope(mut self) -> ScopeInference<'db> {
self.infer_region();
let db = self.db();
let Self {
context,
@ -12714,19 +12715,25 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
called_functions: _,
index: _,
region: _,
return_types_and_ranges: _,
return_types_and_ranges,
} = self;
let _ = scope;
let diagnostics = context.finish();
let extra =
(!string_annotations.is_empty() || !diagnostics.is_empty() || cycle_recovery.is_some())
let extra = (!string_annotations.is_empty()
|| !diagnostics.is_empty()
|| cycle_recovery.is_some()
|| scope.is_non_lambda_function(db))
.then(|| {
let return_types = return_types_and_ranges
.into_iter()
.map(|ty_range| ty_range.ty)
.collect();
Box::new(ScopeInferenceExtra {
string_annotations,
cycle_recovery,
diagnostics,
return_types,
})
});