mirror of https://github.com/astral-sh/ruff
Merge 505dcc81ac into 682d29c256
This commit is contained in:
commit
4a02968fdf
|
|
@ -194,7 +194,8 @@ static SYMPY: Benchmark = Benchmark::new(
|
|||
max_dep_date: "2025-06-17",
|
||||
python_version: PythonVersion::PY312,
|
||||
},
|
||||
13030,
|
||||
// TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably.
|
||||
70000,
|
||||
);
|
||||
|
||||
static TANJUN: Benchmark = Benchmark::new(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
# Regression test for https://github.com/astral-sh/ruff/issues/17371
|
||||
# panicked in commit d1088545a08aeb57b67ec1e3a7f5141159efefa5
|
||||
# error message:
|
||||
# dependency graph cycle when querying ClassType < 'db >::into_callable_(Id(1c00))
|
||||
|
||||
try:
|
||||
class foo[T: bar](object):
|
||||
pass
|
||||
bar = foo
|
||||
except Exception:
|
||||
bar = lambda: 0
|
||||
def bar():
|
||||
pass
|
||||
|
||||
@bar()
|
||||
class bar:
|
||||
pass
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
def f(cond: bool):
|
||||
if cond:
|
||||
result = ()
|
||||
result += (f(cond),)
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
reveal_type(f(True))
|
||||
|
||||
def f(cond: bool):
|
||||
if cond:
|
||||
result = ()
|
||||
result += (f(cond),)
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def f(cond: bool):
|
||||
result = None
|
||||
if cond:
|
||||
result = ()
|
||||
result += (f(cond),)
|
||||
|
||||
return result
|
||||
|
||||
reveal_type(f(True))
|
||||
|
||||
def f(cond: bool):
|
||||
result = None
|
||||
if cond:
|
||||
result = [f(cond) for _ in range(1)]
|
||||
|
||||
return result
|
||||
|
||||
reveal_type(f(True))
|
||||
|
||||
class Foo:
|
||||
def value(self):
|
||||
return 1
|
||||
|
||||
def unwrap(value):
|
||||
if isinstance(value, Foo):
|
||||
foo = value
|
||||
return foo.value()
|
||||
elif type(value) is tuple:
|
||||
length = len(value)
|
||||
if length == 0:
|
||||
return ()
|
||||
elif length == 1:
|
||||
return (unwrap(value[0]),)
|
||||
else:
|
||||
result = []
|
||||
for item in value:
|
||||
result.append(unwrap(item))
|
||||
return tuple(result)
|
||||
else:
|
||||
raise TypeError()
|
||||
|
||||
def descent(x: int, y: int):
|
||||
if x > y:
|
||||
y, x = descent(y, x)
|
||||
return x, y
|
||||
if x == 1:
|
||||
return (1, 0)
|
||||
if y == 1:
|
||||
return (0, 1)
|
||||
else:
|
||||
return descent(x-1, y-1)
|
||||
|
||||
def count_set_bits(n):
|
||||
return 1 + count_set_bits(n & n - 1) if n else 0
|
||||
|
|
@ -79,7 +79,9 @@ def outer_sync(): # `yield` from is only valid syntax inside a synchronous func
|
|||
a: (yield from [1]), # error: [invalid-type-form] "`yield from` expressions are not allowed in type expressions"
|
||||
): ...
|
||||
|
||||
async def baz(): ...
|
||||
async def baz():
|
||||
yield
|
||||
|
||||
async def outer_async(): # avoid unrelated syntax errors on `yield` and `await`
|
||||
def _(
|
||||
a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression"
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ def _(flag: bool):
|
|||
|
||||
# error: [call-non-callable] "Object of type `Literal["This is a string literal"]` is not callable"
|
||||
x = f(3)
|
||||
reveal_type(x) # revealed: Unknown
|
||||
reveal_type(x) # revealed: None | Unknown
|
||||
```
|
||||
|
||||
## Union of binding errors
|
||||
|
|
@ -128,7 +128,7 @@ def _(flag: bool):
|
|||
# error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1"
|
||||
# error: [too-many-positional-arguments] "Too many positional arguments to function `f2`: expected 0, got 1"
|
||||
x = f(3)
|
||||
reveal_type(x) # revealed: Unknown
|
||||
reveal_type(x) # revealed: None
|
||||
```
|
||||
|
||||
## One not-callable, one wrong argument
|
||||
|
|
@ -146,7 +146,7 @@ def _(flag: bool):
|
|||
# error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1"
|
||||
# error: [call-non-callable] "Object of type `C` is not callable"
|
||||
x = f(3)
|
||||
reveal_type(x) # revealed: Unknown
|
||||
reveal_type(x) # revealed: None | Unknown
|
||||
```
|
||||
|
||||
## Union including a special-cased function
|
||||
|
|
|
|||
|
|
@ -125,7 +125,8 @@ match obj:
|
|||
|
||||
```py
|
||||
class C:
|
||||
def __await__(self): ...
|
||||
def __await__(self):
|
||||
yield
|
||||
|
||||
# error: [invalid-syntax] "`return` statement outside of a function"
|
||||
return
|
||||
|
|
|
|||
|
|
@ -295,6 +295,331 @@ def f(cond: bool) -> int:
|
|||
return 2
|
||||
```
|
||||
|
||||
## Inferred return type
|
||||
|
||||
### Free function
|
||||
|
||||
If a function's return type is not annotated, it is inferred. The inferred type is the union of all
|
||||
possible return types.
|
||||
|
||||
```py
|
||||
def f():
|
||||
return 1
|
||||
|
||||
reveal_type(f()) # revealed: Literal[1]
|
||||
# TODO: should be `def f() -> Literal[1]`
|
||||
reveal_type(f) # revealed: def f() -> Unknown
|
||||
|
||||
def g(cond: bool):
|
||||
if cond:
|
||||
return 1
|
||||
else:
|
||||
return "a"
|
||||
|
||||
reveal_type(g(True)) # revealed: Literal[1, "a"]
|
||||
|
||||
# This function implicitly returns `None`.
|
||||
def h(x: int, y: str):
|
||||
if x > 10:
|
||||
return x
|
||||
elif x > 5:
|
||||
return y
|
||||
|
||||
reveal_type(h(1, "a")) # revealed: int | str | None
|
||||
|
||||
lambda_func = lambda: 1
|
||||
# TODO: lambda function type inference
|
||||
# Should be `Literal[1]`
|
||||
reveal_type(lambda_func()) # revealed: Unknown
|
||||
|
||||
def generator():
|
||||
yield 1
|
||||
yield 2
|
||||
return None
|
||||
|
||||
# TODO: Should be `Generator[Literal[1, 2], Any, None]`
|
||||
reveal_type(generator()) # revealed: Unknown
|
||||
|
||||
async def async_generator():
|
||||
yield
|
||||
|
||||
# TODO: Should be `AsyncGenerator[None, Any]`
|
||||
reveal_type(async_generator()) # revealed: Unknown
|
||||
|
||||
async def coroutine():
|
||||
return
|
||||
|
||||
# TODO: Should be `CoroutineType[Any, Any, None]`
|
||||
reveal_type(coroutine()) # revealed: Unknown
|
||||
```
|
||||
|
||||
The return type of a recursive function is also inferred. When the return type inference would
|
||||
diverge, it is truncated and replaced with the special dynamic type `Divergent`.
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
```py
|
||||
def fibonacci(n: int):
|
||||
if n == 0:
|
||||
return 0
|
||||
elif n == 1:
|
||||
return 1
|
||||
else:
|
||||
return fibonacci(n - 1) + fibonacci(n - 2)
|
||||
|
||||
reveal_type(fibonacci(5)) # revealed: int
|
||||
|
||||
def even(n: int):
|
||||
if n == 0:
|
||||
return True
|
||||
else:
|
||||
return odd(n - 1)
|
||||
|
||||
def odd(n: int):
|
||||
if n == 0:
|
||||
return False
|
||||
else:
|
||||
return even(n - 1)
|
||||
|
||||
reveal_type(even(1)) # revealed: bool
|
||||
reveal_type(odd(1)) # revealed: bool
|
||||
|
||||
def repeat_a(n: int):
|
||||
if n <= 0:
|
||||
return ""
|
||||
else:
|
||||
return repeat_a(n - 1) + "a"
|
||||
|
||||
reveal_type(repeat_a(3)) # revealed: str
|
||||
|
||||
def divergent(value):
|
||||
if type(value) is tuple:
|
||||
return (divergent(value[0]),)
|
||||
else:
|
||||
return None
|
||||
|
||||
# tuple[tuple[tuple[...] | None] | None] | None => tuple[Divergent] | None
|
||||
reveal_type(divergent((1,))) # revealed: tuple[Divergent] | None
|
||||
|
||||
def call_divergent(x: int):
|
||||
return (divergent((1, 2, 3)), x)
|
||||
|
||||
reveal_type(call_divergent(1)) # revealed: tuple[tuple[Divergent] | None, int]
|
||||
|
||||
def list1[T](x: T) -> list[T]:
|
||||
return [x]
|
||||
|
||||
def divergent2(value):
|
||||
if type(value) is tuple:
|
||||
return (divergent2(value[0]),)
|
||||
elif type(value) is list:
|
||||
return list1(divergent2(value[0]))
|
||||
else:
|
||||
return None
|
||||
|
||||
reveal_type(divergent2((1,))) # revealed: tuple[Divergent] | list[Divergent] | None
|
||||
|
||||
def list_int(x: int):
|
||||
if x > 0:
|
||||
return list1(list_int(x - 1))
|
||||
else:
|
||||
return list1(x)
|
||||
|
||||
# TODO: should be `list[int]`
|
||||
reveal_type(list_int(1)) # revealed: list[Divergent] | list[int]
|
||||
|
||||
def tuple_obj(cond: bool):
|
||||
if cond:
|
||||
x = object()
|
||||
else:
|
||||
x = tuple_obj(cond)
|
||||
return (x,)
|
||||
|
||||
reveal_type(tuple_obj(True)) # revealed: tuple[object]
|
||||
|
||||
def get_non_empty(node):
|
||||
for child in node.children:
|
||||
node = get_non_empty(child)
|
||||
if node is not None:
|
||||
return node
|
||||
return None
|
||||
|
||||
reveal_type(get_non_empty(None)) # revealed: (Divergent & ~None) | None
|
||||
|
||||
def nested_scope():
|
||||
def inner():
|
||||
return nested_scope()
|
||||
return inner()
|
||||
|
||||
reveal_type(nested_scope()) # revealed: Never
|
||||
|
||||
def eager_nested_scope():
|
||||
class A:
|
||||
x = eager_nested_scope()
|
||||
|
||||
return A.x
|
||||
|
||||
reveal_type(eager_nested_scope()) # revealed: Unknown
|
||||
|
||||
class C:
|
||||
def flip(self) -> "D":
|
||||
return D()
|
||||
|
||||
class D(C):
|
||||
# TODO invalid override error
|
||||
def flip(self) -> "C":
|
||||
return C()
|
||||
|
||||
def c_or_d(n: int):
|
||||
if n == 0:
|
||||
return D()
|
||||
else:
|
||||
return c_or_d(n - 1).flip()
|
||||
|
||||
# In fixed-point iteration of the return type inference, the return type is monotonically widened.
|
||||
# For example, once the return type of `c_or_d` is determined to be `C`,
|
||||
# it will never be determined to be a subtype `D` in the subsequent iterations.
|
||||
reveal_type(c_or_d(1)) # revealed: C
|
||||
```
|
||||
|
||||
### Class method
|
||||
|
||||
If a method's return type is not annotated, it is also inferred, but the inferred type is a union of
|
||||
all possible return types and `Unknown`. This is because a method of a class may be overridden by
|
||||
its subtypes. For example, if the return type of a method is inferred to be `int`, the type the
|
||||
coder really intended might be `int | None`, in which case it would be impossible for the overridden
|
||||
method to return `None`.
|
||||
|
||||
```py
|
||||
class C:
|
||||
def f(self):
|
||||
return 1
|
||||
|
||||
class D(C):
|
||||
def f(self):
|
||||
return None
|
||||
|
||||
reveal_type(C().f()) # revealed: Literal[1] | Unknown
|
||||
reveal_type(D().f()) # revealed: None | Literal[1] | Unknown
|
||||
```
|
||||
|
||||
However, in the following cases, `Unknown` is not included in the inferred return type because there
|
||||
is no ambiguity in the subclass.
|
||||
|
||||
- The class or the method is marked as `final`.
|
||||
|
||||
```py
|
||||
from typing import final
|
||||
|
||||
@final
|
||||
class C:
|
||||
def f(self):
|
||||
return 1
|
||||
|
||||
class D:
|
||||
@final
|
||||
def f(self):
|
||||
return "a"
|
||||
|
||||
reveal_type(C().f()) # revealed: Literal[1]
|
||||
reveal_type(D().f()) # revealed: Literal["a"]
|
||||
```
|
||||
|
||||
- The method overrides the methods of the base classes, and the return types of the base class
|
||||
methods are known (In this case, the return type of the method is the intersection of the return
|
||||
types of the methods in the base classes).
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
class C:
|
||||
def f(self) -> int:
|
||||
return 1
|
||||
|
||||
def g[T](self, x: T) -> T:
|
||||
return x
|
||||
|
||||
def h[T: int](self, x: T) -> T:
|
||||
return x
|
||||
|
||||
def i[T: int](self, x: T) -> list[T]:
|
||||
return [x]
|
||||
|
||||
class D(C):
|
||||
def f(self):
|
||||
return 2
|
||||
# TODO: This should be an invalid-override error.
|
||||
def g(self, x: int):
|
||||
return 2
|
||||
# A strict application of the Liskov Substitution Principle would consider
|
||||
# this an invalid override because it violates the guarantee that the method returns
|
||||
# the same type as its input type (any type smaller than int),
|
||||
# but neither mypy nor pyright will throw an error for this.
|
||||
def h(self, x: int):
|
||||
return 2
|
||||
|
||||
def i(self, x: int):
|
||||
return [2]
|
||||
|
||||
class E(D):
|
||||
def f(self):
|
||||
return 3
|
||||
|
||||
reveal_type(C().f()) # revealed: int
|
||||
reveal_type(D().f()) # revealed: int
|
||||
reveal_type(E().f()) # revealed: int
|
||||
reveal_type(C().g(1)) # revealed: Literal[1]
|
||||
reveal_type(D().g(1)) # revealed: Literal[2] | Unknown
|
||||
reveal_type(C().h(1)) # revealed: Literal[1]
|
||||
reveal_type(D().h(1)) # revealed: Literal[2] | Unknown
|
||||
reveal_type(C().h(True)) # revealed: Literal[True]
|
||||
reveal_type(D().h(True)) # revealed: Literal[2] | Unknown
|
||||
reveal_type(C().i(1)) # revealed: list[Literal[1]]
|
||||
# TODO: better type for list elements
|
||||
reveal_type(D().i(1)) # revealed: list[Unknown | int] | list[Unknown]
|
||||
|
||||
class F:
|
||||
def f(self) -> Literal[1, 2]:
|
||||
return 2
|
||||
|
||||
class G:
|
||||
def f(self) -> Literal[2, 3]:
|
||||
return 2
|
||||
|
||||
class H(F, G):
|
||||
# TODO: should be an invalid-override error
|
||||
def f(self):
|
||||
raise NotImplementedError
|
||||
|
||||
class I(F, G):
|
||||
# TODO: should be an invalid-override error
|
||||
@final
|
||||
def f(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# We use a return type of `F.f` according to the MRO.
|
||||
reveal_type(H().f()) # revealed: Literal[1, 2]
|
||||
reveal_type(I().f()) # revealed: Never
|
||||
|
||||
class C2[T]:
|
||||
def f(self, x: T) -> T:
|
||||
return x
|
||||
|
||||
class D2(C2[int]):
|
||||
def f(self, x: int):
|
||||
return x
|
||||
|
||||
reveal_type(D2().f(1)) # revealed: int
|
||||
```
|
||||
|
||||
## Invalid return type
|
||||
|
||||
<!-- snapshot-diagnostics -->
|
||||
|
|
|
|||
|
|
@ -142,20 +142,25 @@ def _(x: A | B):
|
|||
reveal_type(x) # revealed: A | B
|
||||
```
|
||||
|
||||
## No narrowing for custom `type` callable
|
||||
## No special narrowing for custom `type` callable
|
||||
|
||||
```py
|
||||
def type(x: object):
|
||||
return int
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
|
||||
def type(x):
|
||||
return int
|
||||
|
||||
def _(x: A | B):
|
||||
# The custom `type` function always returns `int`,
|
||||
# so any branch other than `type(...) is int` is unreachable.
|
||||
if type(x) is A:
|
||||
reveal_type(x) # revealed: Never
|
||||
# And the condition here is always `True` and has no effect on the narrowing of `x`.
|
||||
elif type(x) is int:
|
||||
reveal_type(x) # revealed: A | B
|
||||
else:
|
||||
reveal_type(x) # revealed: A | B
|
||||
reveal_type(x) # revealed: Never
|
||||
```
|
||||
|
||||
## No narrowing for multiple arguments
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use ruff_db::{files::File, parsed::ParsedModuleRef};
|
||||
use ruff_db::{
|
||||
files::File,
|
||||
parsed::{ParsedModuleRef, parsed_module},
|
||||
};
|
||||
use ruff_index::newtype_index;
|
||||
use ruff_python_ast as ast;
|
||||
|
||||
|
|
@ -27,6 +30,10 @@ pub struct ScopeId<'db> {
|
|||
impl get_size2::GetSize for ScopeId<'_> {}
|
||||
|
||||
impl<'db> ScopeId<'db> {
|
||||
pub(crate) fn is_non_lambda_function(self, db: &'db dyn Db) -> bool {
|
||||
self.node(db).scope_kind().is_non_lambda_function()
|
||||
}
|
||||
|
||||
pub(crate) fn is_annotation(self, db: &'db dyn Db) -> bool {
|
||||
self.node(db).scope_kind().is_annotation()
|
||||
}
|
||||
|
|
@ -64,6 +71,18 @@ impl<'db> ScopeId<'db> {
|
|||
NodeWithScopeKind::GeneratorExpression(_) => "<generator>",
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_coroutine_function(self, db: &'db dyn Db) -> bool {
|
||||
let module = parsed_module(db, self.file(db)).load(db);
|
||||
self.node(db)
|
||||
.as_function()
|
||||
.is_some_and(|func| func.node(&module).is_async && !self.is_generator_function(db))
|
||||
}
|
||||
|
||||
pub(crate) fn is_generator_function(self, db: &'db dyn Db) -> bool {
|
||||
let index = semantic_index(db, self.file(db));
|
||||
self.file_scope_id(db).is_generator_function(index)
|
||||
}
|
||||
}
|
||||
|
||||
/// ID that uniquely identifies a scope inside of a module.
|
||||
|
|
|
|||
|
|
@ -6454,6 +6454,19 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the inferred return type of `self` if it is a function literal / bound method.
|
||||
fn infer_return_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
|
||||
match self {
|
||||
Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => {
|
||||
Some(function_type.infer_return_type(db))
|
||||
}
|
||||
Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db) => {
|
||||
Some(method_type.infer_return_type(db))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if
|
||||
/// the arguments are not compatible with the formal parameters.
|
||||
///
|
||||
|
|
@ -12156,6 +12169,77 @@ impl<'db> BoundMethodType<'db> {
|
|||
)
|
||||
}
|
||||
|
||||
/// Infers this method scope's types and returns the inferred return type.
|
||||
#[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)]
|
||||
pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||
let scope = self
|
||||
.function(db)
|
||||
.literal(db)
|
||||
.last_definition(db)
|
||||
.body_scope(db);
|
||||
let inference = infer_scope_types(db, scope);
|
||||
inference.infer_return_type(db, scope, Type::BoundMethod(self))
|
||||
}
|
||||
|
||||
#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)]
|
||||
fn class_definition(self, db: &'db dyn Db) -> Option<Definition<'db>> {
|
||||
let definition_scope = self.function(db).definition(db).scope(db);
|
||||
let index = semantic_index(db, definition_scope.file(db));
|
||||
Some(index.expect_single_definition(definition_scope.node(db).as_class()?))
|
||||
}
|
||||
|
||||
pub(crate) fn is_final(self, db: &'db dyn Db) -> bool {
|
||||
if self
|
||||
.function(db)
|
||||
.has_known_decorator(db, FunctionDecorators::FINAL)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
let Some(class_ty) = self
|
||||
.class_definition(db)
|
||||
.and_then(|class| binding_type(db, class).as_class_literal())
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
class_ty
|
||||
.known_function_decorators(db)
|
||||
.any(|deco| deco == KnownFunction::Final)
|
||||
}
|
||||
|
||||
pub(super) fn base_return_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
|
||||
let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?;
|
||||
let name = self.function(db).name(db);
|
||||
|
||||
let base = class
|
||||
.iter_mro(db)
|
||||
.nth(1)
|
||||
.and_then(class_base::ClassBase::into_class)?;
|
||||
let base_member = base.class_member(db, name, MemberLookupPolicy::default());
|
||||
if let Place::Defined(Type::FunctionLiteral(base_func), _, _) = base_member.place {
|
||||
if let [signature] = base_func.signature(db).overloads.as_slice() {
|
||||
let unspecialized_return_ty = signature.return_ty.unwrap_or_else(|| {
|
||||
let base_method_ty =
|
||||
base_func.into_bound_method_type(db, Type::instance(db, class));
|
||||
base_method_ty.infer_return_type(db)
|
||||
});
|
||||
if let Some(generic_context) = signature.generic_context.as_ref() {
|
||||
// If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables.
|
||||
Some(
|
||||
unspecialized_return_ty
|
||||
.apply_specialization(db, generic_context.unknown_specialization(db)),
|
||||
)
|
||||
} else {
|
||||
Some(unspecialized_return_ty)
|
||||
}
|
||||
} else {
|
||||
// TODO: Handle overloaded base methods.
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
|
||||
Self::new(
|
||||
db,
|
||||
|
|
@ -12233,6 +12317,24 @@ impl<'db> BoundMethodType<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn return_type_cycle_recover<'db>(
|
||||
db: &'db dyn Db,
|
||||
cycle: &salsa::Cycle,
|
||||
previous_return_type: &Type<'db>,
|
||||
return_type: Type<'db>,
|
||||
_self: BoundMethodType<'db>,
|
||||
) -> Type<'db> {
|
||||
return_type.cycle_normalized(db, *previous_return_type, cycle)
|
||||
}
|
||||
|
||||
fn return_type_cycle_initial<'db>(
|
||||
_db: &'db dyn Db,
|
||||
id: salsa::Id,
|
||||
_method: BoundMethodType<'db>,
|
||||
) -> Type<'db> {
|
||||
Type::divergent(id)
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, get_size2::GetSize)]
|
||||
pub enum CallableTypeKind {
|
||||
/// Represents regular callable objects.
|
||||
|
|
|
|||
|
|
@ -3650,10 +3650,15 @@ impl<'db> Binding<'db> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (keywords_index, keywords_type) in keywords_arguments {
|
||||
matcher.match_keyword_variadic(db, keywords_index, keywords_type);
|
||||
}
|
||||
self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown());
|
||||
self.return_ty = self.signature.return_ty.unwrap_or_else(|| {
|
||||
self.callable_type
|
||||
.infer_return_type(db)
|
||||
.unwrap_or(Type::unknown())
|
||||
});
|
||||
self.parameter_tys = vec![None; parameters.len()].into_boxed_slice();
|
||||
self.variadic_argument_matched_to_variadic_parameter =
|
||||
matcher.variadic_argument_matched_to_variadic_parameter;
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ use crate::types::{
|
|||
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType,
|
||||
NormalizedVisitor, SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type,
|
||||
TypeContext, TypeMapping, TypeRelation, UnionBuilder, binding_type, definition_expression_type,
|
||||
infer_definition_types, walk_signature,
|
||||
infer_definition_types, infer_scope_types, walk_signature,
|
||||
};
|
||||
use crate::{Db, FxOrderSet, ModuleName, resolve_module};
|
||||
|
||||
|
|
@ -1197,6 +1197,32 @@ impl<'db> FunctionType<'db> {
|
|||
updated_last_definition_signature,
|
||||
))
|
||||
}
|
||||
|
||||
/// Infers this function scope's types and returns the inferred return type.
|
||||
#[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)]
|
||||
pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||
let scope = self.literal(db).last_definition(db).body_scope(db);
|
||||
let inference = infer_scope_types(db, scope);
|
||||
inference.infer_return_type(db, scope, Type::FunctionLiteral(self))
|
||||
}
|
||||
}
|
||||
|
||||
fn return_type_cycle_recover<'db>(
|
||||
db: &'db dyn Db,
|
||||
cycle: &salsa::Cycle,
|
||||
previous_return_type: &Type<'db>,
|
||||
return_type: Type<'db>,
|
||||
_self: FunctionType<'db>,
|
||||
) -> Type<'db> {
|
||||
return_type.cycle_normalized(db, *previous_return_type, cycle)
|
||||
}
|
||||
|
||||
fn return_type_cycle_initial<'db>(
|
||||
_db: &'db dyn Db,
|
||||
id: salsa::Id,
|
||||
_function: FunctionType<'db>,
|
||||
) -> Type<'db> {
|
||||
Type::divergent(id)
|
||||
}
|
||||
|
||||
/// Evaluate an `isinstance` call. Return `Truthiness::AlwaysTrue` if we can definitely infer that
|
||||
|
|
|
|||
|
|
@ -47,13 +47,13 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
|
|||
use crate::semantic_index::definition::Definition;
|
||||
use crate::semantic_index::expression::Expression;
|
||||
use crate::semantic_index::scope::ScopeId;
|
||||
use crate::semantic_index::{SemanticIndex, semantic_index};
|
||||
use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map};
|
||||
use crate::types::diagnostic::TypeCheckDiagnostics;
|
||||
use crate::types::function::FunctionType;
|
||||
use crate::types::generics::Specialization;
|
||||
use crate::types::unpacker::{UnpackResult, Unpacker};
|
||||
use crate::types::{
|
||||
ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, declaration_type,
|
||||
ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, UnionBuilder, declaration_type,
|
||||
};
|
||||
use crate::unpack::Unpack;
|
||||
use builder::TypeInferenceBuilder;
|
||||
|
|
@ -546,6 +546,9 @@ struct ScopeInferenceExtra<'db> {
|
|||
|
||||
/// The diagnostics for this region.
|
||||
diagnostics: TypeCheckDiagnostics,
|
||||
|
||||
/// The returned types, if it is a function body.
|
||||
return_types: Vec<Type<'db>>,
|
||||
}
|
||||
|
||||
impl<'db> ScopeInference<'db> {
|
||||
|
|
@ -605,6 +608,51 @@ impl<'db> ScopeInference<'db> {
|
|||
|
||||
extra.string_annotations.contains(&expression.into())
|
||||
}
|
||||
|
||||
/// Returns the inferred return type of this function body (union of all possible return types),
|
||||
/// or `None` if the region is not a function body.
|
||||
/// In the case of methods, the return type of the superclass method is further unioned.
|
||||
/// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`.
|
||||
pub(crate) fn infer_return_type(
|
||||
&self,
|
||||
db: &'db dyn Db,
|
||||
scope: ScopeId<'db>,
|
||||
callee_ty: Type<'db>,
|
||||
) -> Type<'db> {
|
||||
// TODO: coroutine function type inference
|
||||
// TODO: generator function type inference
|
||||
if scope.is_coroutine_function(db) || scope.is_generator_function(db) {
|
||||
return Type::unknown();
|
||||
}
|
||||
|
||||
let mut union = UnionBuilder::new(db);
|
||||
if let Some(cycle_recovery) = self.fallback_type() {
|
||||
union = union.add(cycle_recovery);
|
||||
}
|
||||
|
||||
let Some(extra) = &self.extra else {
|
||||
unreachable!(
|
||||
"infer_return_type should only be called on a function body scope inference"
|
||||
);
|
||||
};
|
||||
for return_ty in &extra.return_types {
|
||||
union = union.add(*return_ty);
|
||||
}
|
||||
let use_def = use_def_map(db, scope);
|
||||
if use_def.can_implicitly_return_none(db) {
|
||||
union = union.add(Type::none(db));
|
||||
}
|
||||
if let Type::BoundMethod(method_ty) = callee_ty {
|
||||
// If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`.
|
||||
// If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`.
|
||||
// On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`.
|
||||
if !method_ty.is_final(db) {
|
||||
union = union.add(method_ty.base_return_type(db).unwrap_or(Type::unknown()));
|
||||
}
|
||||
}
|
||||
|
||||
union.build()
|
||||
}
|
||||
}
|
||||
|
||||
/// The inferred types for a definition region.
|
||||
|
|
|
|||
|
|
@ -12688,6 +12688,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
pub(super) fn finish_scope(mut self) -> ScopeInference<'db> {
|
||||
self.infer_region();
|
||||
let db = self.db();
|
||||
|
||||
let Self {
|
||||
context,
|
||||
|
|
@ -12714,19 +12715,25 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
called_functions: _,
|
||||
index: _,
|
||||
region: _,
|
||||
return_types_and_ranges: _,
|
||||
return_types_and_ranges,
|
||||
} = self;
|
||||
|
||||
let _ = scope;
|
||||
let diagnostics = context.finish();
|
||||
|
||||
let extra =
|
||||
(!string_annotations.is_empty() || !diagnostics.is_empty() || cycle_recovery.is_some())
|
||||
let extra = (!string_annotations.is_empty()
|
||||
|| !diagnostics.is_empty()
|
||||
|| cycle_recovery.is_some()
|
||||
|| scope.is_non_lambda_function(db))
|
||||
.then(|| {
|
||||
let return_types = return_types_and_ranges
|
||||
.into_iter()
|
||||
.map(|ty_range| ty_range.ty)
|
||||
.collect();
|
||||
Box::new(ScopeInferenceExtra {
|
||||
string_annotations,
|
||||
cycle_recovery,
|
||||
diagnostics,
|
||||
return_types,
|
||||
})
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue