diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index 7a0f9a5848..f3ef925e19 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -194,7 +194,8 @@ static SYMPY: Benchmark = Benchmark::new( max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 13030, + // TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably. + 70000, ); static TANJUN: Benchmark = Benchmark::new( diff --git a/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py new file mode 100644 index 0000000000..ce4cd6a795 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py @@ -0,0 +1,17 @@ +# Regression test for https://github.com/astral-sh/ruff/issues/17371 +# panicked in commit d1088545a08aeb57b67ec1e3a7f5141159efefa5 +# error message: +# dependency graph cycle when querying ClassType < 'db >::into_callable_(Id(1c00)) + +try: + class foo[T: bar](object): + pass + bar = foo +except Exception: + bar = lambda: 0 +def bar(): + pass + +@bar() +class bar: + pass diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py new file mode 100644 index 0000000000..1ef6726cf2 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -0,0 +1,72 @@ +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +reveal_type(f(True)) + +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +def f(cond: bool): + result = None + if cond: + result = () + result += (f(cond),) + + return result + +reveal_type(f(True)) + +def f(cond: bool): + result = None + if cond: + result = [f(cond) for _ in range(1)] + + return result + +reveal_type(f(True)) + +class Foo: + def value(self): + return 1 + +def unwrap(value): + if isinstance(value, Foo): + foo = value + return foo.value() + elif type(value) is tuple: + length = len(value) + if length == 0: + return () + elif length == 1: + return (unwrap(value[0]),) + else: + result = [] + for item in value: + result.append(unwrap(item)) + return tuple(result) + else: + raise TypeError() + +def descent(x: int, y: int): + if x > y: + y, x = descent(y, x) + return x, y + if x == 1: + return (1, 0) + if y == 1: + return (0, 1) + else: + return descent(x-1, y-1) + +def count_set_bits(n): + return 1 + count_set_bits(n & n - 1) if n else 0 diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md index de5f16dcbb..cb9f65c875 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md @@ -79,7 +79,9 @@ def outer_sync(): # `yield` from is only valid syntax inside a synchronous func a: (yield from [1]), # error: [invalid-type-form] "`yield from` expressions are not allowed in type expressions" ): ... -async def baz(): ... +async def baz(): + yield + async def outer_async(): # avoid unrelated syntax errors on `yield` and `await` def _( a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression" diff --git a/crates/ty_python_semantic/resources/mdtest/call/union.md b/crates/ty_python_semantic/resources/mdtest/call/union.md index 09943c0801..0b0095a392 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/union.md +++ b/crates/ty_python_semantic/resources/mdtest/call/union.md @@ -111,7 +111,7 @@ def _(flag: bool): # error: [call-non-callable] "Object of type `Literal["This is a string literal"]` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union of binding errors @@ -128,7 +128,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [too-many-positional-arguments] "Too many positional arguments to function `f2`: expected 0, got 1" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None ``` ## One not-callable, one wrong argument @@ -146,7 +146,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [call-non-callable] "Object of type `C` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union including a special-cased function diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md index b4e0b1ae24..5c2c958a15 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md @@ -125,7 +125,8 @@ match obj: ```py class C: - def __await__(self): ... + def __await__(self): + yield # error: [invalid-syntax] "`return` statement outside of a function" return diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 81ccb339e7..3c6fbc839e 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -295,6 +295,331 @@ def f(cond: bool) -> int: return 2 ``` +## Inferred return type + +### Free function + +If a function's return type is not annotated, it is inferred. The inferred type is the union of all +possible return types. + +```py +def f(): + return 1 + +reveal_type(f()) # revealed: Literal[1] +# TODO: should be `def f() -> Literal[1]` +reveal_type(f) # revealed: def f() -> Unknown + +def g(cond: bool): + if cond: + return 1 + else: + return "a" + +reveal_type(g(True)) # revealed: Literal[1, "a"] + +# This function implicitly returns `None`. +def h(x: int, y: str): + if x > 10: + return x + elif x > 5: + return y + +reveal_type(h(1, "a")) # revealed: int | str | None + +lambda_func = lambda: 1 +# TODO: lambda function type inference +# Should be `Literal[1]` +reveal_type(lambda_func()) # revealed: Unknown + +def generator(): + yield 1 + yield 2 + return None + +# TODO: Should be `Generator[Literal[1, 2], Any, None]` +reveal_type(generator()) # revealed: Unknown + +async def async_generator(): + yield + +# TODO: Should be `AsyncGenerator[None, Any]` +reveal_type(async_generator()) # revealed: Unknown + +async def coroutine(): + return + +# TODO: Should be `CoroutineType[Any, Any, None]` +reveal_type(coroutine()) # revealed: Unknown +``` + +The return type of a recursive function is also inferred. When the return type inference would +diverge, it is truncated and replaced with the special dynamic type `Divergent`. + +```toml +[environment] +python-version = "3.12" +``` + +```py +def fibonacci(n: int): + if n == 0: + return 0 + elif n == 1: + return 1 + else: + return fibonacci(n - 1) + fibonacci(n - 2) + +reveal_type(fibonacci(5)) # revealed: int + +def even(n: int): + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int): + if n == 0: + return False + else: + return even(n - 1) + +reveal_type(even(1)) # revealed: bool +reveal_type(odd(1)) # revealed: bool + +def repeat_a(n: int): + if n <= 0: + return "" + else: + return repeat_a(n - 1) + "a" + +reveal_type(repeat_a(3)) # revealed: str + +def divergent(value): + if type(value) is tuple: + return (divergent(value[0]),) + else: + return None + +# tuple[tuple[tuple[...] | None] | None] | None => tuple[Divergent] | None +reveal_type(divergent((1,))) # revealed: tuple[Divergent] | None + +def call_divergent(x: int): + return (divergent((1, 2, 3)), x) + +reveal_type(call_divergent(1)) # revealed: tuple[tuple[Divergent] | None, int] + +def list1[T](x: T) -> list[T]: + return [x] + +def divergent2(value): + if type(value) is tuple: + return (divergent2(value[0]),) + elif type(value) is list: + return list1(divergent2(value[0])) + else: + return None + +reveal_type(divergent2((1,))) # revealed: tuple[Divergent] | list[Divergent] | None + +def list_int(x: int): + if x > 0: + return list1(list_int(x - 1)) + else: + return list1(x) + +# TODO: should be `list[int]` +reveal_type(list_int(1)) # revealed: list[Divergent] | list[int] + +def tuple_obj(cond: bool): + if cond: + x = object() + else: + x = tuple_obj(cond) + return (x,) + +reveal_type(tuple_obj(True)) # revealed: tuple[object] + +def get_non_empty(node): + for child in node.children: + node = get_non_empty(child) + if node is not None: + return node + return None + +reveal_type(get_non_empty(None)) # revealed: (Divergent & ~None) | None + +def nested_scope(): + def inner(): + return nested_scope() + return inner() + +reveal_type(nested_scope()) # revealed: Never + +def eager_nested_scope(): + class A: + x = eager_nested_scope() + + return A.x + +reveal_type(eager_nested_scope()) # revealed: Unknown + +class C: + def flip(self) -> "D": + return D() + +class D(C): + # TODO invalid override error + def flip(self) -> "C": + return C() + +def c_or_d(n: int): + if n == 0: + return D() + else: + return c_or_d(n - 1).flip() + +# In fixed-point iteration of the return type inference, the return type is monotonically widened. +# For example, once the return type of `c_or_d` is determined to be `C`, +# it will never be determined to be a subtype `D` in the subsequent iterations. +reveal_type(c_or_d(1)) # revealed: C +``` + +### Class method + +If a method's return type is not annotated, it is also inferred, but the inferred type is a union of +all possible return types and `Unknown`. This is because a method of a class may be overridden by +its subtypes. For example, if the return type of a method is inferred to be `int`, the type the +coder really intended might be `int | None`, in which case it would be impossible for the overridden +method to return `None`. + +```py +class C: + def f(self): + return 1 + +class D(C): + def f(self): + return None + +reveal_type(C().f()) # revealed: Literal[1] | Unknown +reveal_type(D().f()) # revealed: None | Literal[1] | Unknown +``` + +However, in the following cases, `Unknown` is not included in the inferred return type because there +is no ambiguity in the subclass. + +- The class or the method is marked as `final`. + +```py +from typing import final + +@final +class C: + def f(self): + return 1 + +class D: + @final + def f(self): + return "a" + +reveal_type(C().f()) # revealed: Literal[1] +reveal_type(D().f()) # revealed: Literal["a"] +``` + +- The method overrides the methods of the base classes, and the return types of the base class + methods are known (In this case, the return type of the method is the intersection of the return + types of the methods in the base classes). + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Literal + +class C: + def f(self) -> int: + return 1 + + def g[T](self, x: T) -> T: + return x + + def h[T: int](self, x: T) -> T: + return x + + def i[T: int](self, x: T) -> list[T]: + return [x] + +class D(C): + def f(self): + return 2 + # TODO: This should be an invalid-override error. + def g(self, x: int): + return 2 + # A strict application of the Liskov Substitution Principle would consider + # this an invalid override because it violates the guarantee that the method returns + # the same type as its input type (any type smaller than int), + # but neither mypy nor pyright will throw an error for this. + def h(self, x: int): + return 2 + + def i(self, x: int): + return [2] + +class E(D): + def f(self): + return 3 + +reveal_type(C().f()) # revealed: int +reveal_type(D().f()) # revealed: int +reveal_type(E().f()) # revealed: int +reveal_type(C().g(1)) # revealed: Literal[1] +reveal_type(D().g(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(1)) # revealed: Literal[1] +reveal_type(D().h(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(True)) # revealed: Literal[True] +reveal_type(D().h(True)) # revealed: Literal[2] | Unknown +reveal_type(C().i(1)) # revealed: list[Literal[1]] +# TODO: better type for list elements +reveal_type(D().i(1)) # revealed: list[Unknown | int] | list[Unknown] + +class F: + def f(self) -> Literal[1, 2]: + return 2 + +class G: + def f(self) -> Literal[2, 3]: + return 2 + +class H(F, G): + # TODO: should be an invalid-override error + def f(self): + raise NotImplementedError + +class I(F, G): + # TODO: should be an invalid-override error + @final + def f(self): + raise NotImplementedError + +# We use a return type of `F.f` according to the MRO. +reveal_type(H().f()) # revealed: Literal[1, 2] +reveal_type(I().f()) # revealed: Never + +class C2[T]: + def f(self, x: T) -> T: + return x + +class D2(C2[int]): + def f(self, x: int): + return x + +reveal_type(D2().f(1)) # revealed: int +``` + ## Invalid return type diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type.md b/crates/ty_python_semantic/resources/mdtest/narrow/type.md index de962d2075..06c0744d7a 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type.md @@ -142,20 +142,25 @@ def _(x: A | B): reveal_type(x) # revealed: A | B ``` -## No narrowing for custom `type` callable +## No special narrowing for custom `type` callable ```py +def type(x: object): + return int + class A: ... class B: ... -def type(x): - return int - def _(x: A | B): + # The custom `type` function always returns `int`, + # so any branch other than `type(...) is int` is unreachable. if type(x) is A: + reveal_type(x) # revealed: Never + # And the condition here is always `True` and has no effect on the narrowing of `x`. + elif type(x) is int: reveal_type(x) # revealed: A | B else: - reveal_type(x) # revealed: A | B + reveal_type(x) # revealed: Never ``` ## No narrowing for multiple arguments diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index c7c42241a3..1ec58a3299 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -1,6 +1,9 @@ use std::ops::Range; -use ruff_db::{files::File, parsed::ParsedModuleRef}; +use ruff_db::{ + files::File, + parsed::{ParsedModuleRef, parsed_module}, +}; use ruff_index::newtype_index; use ruff_python_ast as ast; @@ -27,6 +30,10 @@ pub struct ScopeId<'db> { impl get_size2::GetSize for ScopeId<'_> {} impl<'db> ScopeId<'db> { + pub(crate) fn is_non_lambda_function(self, db: &'db dyn Db) -> bool { + self.node(db).scope_kind().is_non_lambda_function() + } + pub(crate) fn is_annotation(self, db: &'db dyn Db) -> bool { self.node(db).scope_kind().is_annotation() } @@ -64,6 +71,18 @@ impl<'db> ScopeId<'db> { NodeWithScopeKind::GeneratorExpression(_) => "", } } + + pub(crate) fn is_coroutine_function(self, db: &'db dyn Db) -> bool { + let module = parsed_module(db, self.file(db)).load(db); + self.node(db) + .as_function() + .is_some_and(|func| func.node(&module).is_async && !self.is_generator_function(db)) + } + + pub(crate) fn is_generator_function(self, db: &'db dyn Db) -> bool { + let index = semantic_index(db, self.file(db)); + self.file_scope_id(db).is_generator_function(index) + } } /// ID that uniquely identifies a scope inside of a module. diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 2e285d46a7..4c5e202115 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -6454,6 +6454,19 @@ impl<'db> Type<'db> { } } + /// Returns the inferred return type of `self` if it is a function literal / bound method. + fn infer_return_type(self, db: &'db dyn Db) -> Option> { + match self { + Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => { + Some(function_type.infer_return_type(db)) + } + Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db) => { + Some(method_type.infer_return_type(db)) + } + _ => None, + } + } + /// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if /// the arguments are not compatible with the formal parameters. /// @@ -12156,6 +12169,77 @@ impl<'db> BoundMethodType<'db> { ) } + /// Infers this method scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db); + let inference = infer_scope_types(db, scope); + inference.infer_return_type(db, scope, Type::BoundMethod(self)) + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn class_definition(self, db: &'db dyn Db) -> Option> { + let definition_scope = self.function(db).definition(db).scope(db); + let index = semantic_index(db, definition_scope.file(db)); + Some(index.expect_single_definition(definition_scope.node(db).as_class()?)) + } + + pub(crate) fn is_final(self, db: &'db dyn Db) -> bool { + if self + .function(db) + .has_known_decorator(db, FunctionDecorators::FINAL) + { + return true; + } + let Some(class_ty) = self + .class_definition(db) + .and_then(|class| binding_type(db, class).as_class_literal()) + else { + return false; + }; + class_ty + .known_function_decorators(db) + .any(|deco| deco == KnownFunction::Final) + } + + pub(super) fn base_return_type(self, db: &'db dyn Db) -> Option> { + let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; + let name = self.function(db).name(db); + + let base = class + .iter_mro(db) + .nth(1) + .and_then(class_base::ClassBase::into_class)?; + let base_member = base.class_member(db, name, MemberLookupPolicy::default()); + if let Place::Defined(Type::FunctionLiteral(base_func), _, _) = base_member.place { + if let [signature] = base_func.signature(db).overloads.as_slice() { + let unspecialized_return_ty = signature.return_ty.unwrap_or_else(|| { + let base_method_ty = + base_func.into_bound_method_type(db, Type::instance(db, class)); + base_method_ty.infer_return_type(db) + }); + if let Some(generic_context) = signature.generic_context.as_ref() { + // If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables. + Some( + unspecialized_return_ty + .apply_specialization(db, generic_context.unknown_specialization(db)), + ) + } else { + Some(unspecialized_return_ty) + } + } else { + // TODO: Handle overloaded base methods. + None + } + } else { + None + } + } + fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { Self::new( db, @@ -12233,6 +12317,24 @@ impl<'db> BoundMethodType<'db> { } } +fn return_type_cycle_recover<'db>( + db: &'db dyn Db, + cycle: &salsa::Cycle, + previous_return_type: &Type<'db>, + return_type: Type<'db>, + _self: BoundMethodType<'db>, +) -> Type<'db> { + return_type.cycle_normalized(db, *previous_return_type, cycle) +} + +fn return_type_cycle_initial<'db>( + _db: &'db dyn Db, + id: salsa::Id, + _method: BoundMethodType<'db>, +) -> Type<'db> { + Type::divergent(id) +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, get_size2::GetSize)] pub enum CallableTypeKind { /// Represents regular callable objects. diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 005013e70b..c4bfd62989 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3650,10 +3650,15 @@ impl<'db> Binding<'db> { } } } + for (keywords_index, keywords_type) in keywords_arguments { matcher.match_keyword_variadic(db, keywords_index, keywords_type); } - self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown()); + self.return_ty = self.signature.return_ty.unwrap_or_else(|| { + self.callable_type + .infer_return_type(db) + .unwrap_or(Type::unknown()) + }); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); self.variadic_argument_matched_to_variadic_parameter = matcher.variadic_argument_matched_to_variadic_parameter; diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index e727e6663b..b05aa95b82 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -85,7 +85,7 @@ use crate::types::{ HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext, TypeMapping, TypeRelation, UnionBuilder, binding_type, definition_expression_type, - infer_definition_types, walk_signature, + infer_definition_types, infer_scope_types, walk_signature, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; @@ -1197,6 +1197,32 @@ impl<'db> FunctionType<'db> { updated_last_definition_signature, )) } + + /// Infers this function scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self.literal(db).last_definition(db).body_scope(db); + let inference = infer_scope_types(db, scope); + inference.infer_return_type(db, scope, Type::FunctionLiteral(self)) + } +} + +fn return_type_cycle_recover<'db>( + db: &'db dyn Db, + cycle: &salsa::Cycle, + previous_return_type: &Type<'db>, + return_type: Type<'db>, + _self: FunctionType<'db>, +) -> Type<'db> { + return_type.cycle_normalized(db, *previous_return_type, cycle) +} + +fn return_type_cycle_initial<'db>( + _db: &'db dyn Db, + id: salsa::Id, + _function: FunctionType<'db>, +) -> Type<'db> { + Type::divergent(id) } /// Evaluate an `isinstance` call. Return `Truthiness::AlwaysTrue` if we can definitely infer that diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index b9adc93eb2..e565b84622 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -47,13 +47,13 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::{SemanticIndex, semantic_index}; +use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map}; use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::function::FunctionType; use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, declaration_type, + ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers, UnionBuilder, declaration_type, }; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -546,6 +546,9 @@ struct ScopeInferenceExtra<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, + + /// The returned types, if it is a function body. + return_types: Vec>, } impl<'db> ScopeInference<'db> { @@ -605,6 +608,51 @@ impl<'db> ScopeInference<'db> { extra.string_annotations.contains(&expression.into()) } + + /// Returns the inferred return type of this function body (union of all possible return types), + /// or `None` if the region is not a function body. + /// In the case of methods, the return type of the superclass method is further unioned. + /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. + pub(crate) fn infer_return_type( + &self, + db: &'db dyn Db, + scope: ScopeId<'db>, + callee_ty: Type<'db>, + ) -> Type<'db> { + // TODO: coroutine function type inference + // TODO: generator function type inference + if scope.is_coroutine_function(db) || scope.is_generator_function(db) { + return Type::unknown(); + } + + let mut union = UnionBuilder::new(db); + if let Some(cycle_recovery) = self.fallback_type() { + union = union.add(cycle_recovery); + } + + let Some(extra) = &self.extra else { + unreachable!( + "infer_return_type should only be called on a function body scope inference" + ); + }; + for return_ty in &extra.return_types { + union = union.add(*return_ty); + } + let use_def = use_def_map(db, scope); + if use_def.can_implicitly_return_none(db) { + union = union.add(Type::none(db)); + } + if let Type::BoundMethod(method_ty) = callee_ty { + // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. + // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. + // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. + if !method_ty.is_final(db) { + union = union.add(method_ty.base_return_type(db).unwrap_or(Type::unknown())); + } + } + + union.build() + } } /// The inferred types for a definition region. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 76271b07ff..c6577a0b3e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -12688,6 +12688,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { pub(super) fn finish_scope(mut self) -> ScopeInference<'db> { self.infer_region(); + let db = self.db(); let Self { context, @@ -12714,21 +12715,27 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { called_functions: _, index: _, region: _, - return_types_and_ranges: _, + return_types_and_ranges, } = self; - let _ = scope; let diagnostics = context.finish(); - let extra = - (!string_annotations.is_empty() || !diagnostics.is_empty() || cycle_recovery.is_some()) - .then(|| { - Box::new(ScopeInferenceExtra { - string_annotations, - cycle_recovery, - diagnostics, - }) - }); + let extra = (!string_annotations.is_empty() + || !diagnostics.is_empty() + || cycle_recovery.is_some() + || scope.is_non_lambda_function(db)) + .then(|| { + let return_types = return_types_and_ranges + .into_iter() + .map(|ty_range| ty_range.ty) + .collect(); + Box::new(ScopeInferenceExtra { + string_annotations, + cycle_recovery, + diagnostics, + return_types, + }) + }); expressions.shrink_to_fit();