From e84d523bcf95681ba74cec5a0eea1aff918a5fbe Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 17 Sep 2025 18:51:50 -0400 Subject: [PATCH] [ty] Infer more precise types for collection literals (#20360) ## Summary Part of https://github.com/astral-sh/ty/issues/168. Infer more precise types for collection literals (currently, only `list` and `set`). For example, ```py x = [1, 2, 3] # revealed: list[Unknown | int] y: list[int] = [1, 2, 3] # revealed: list[int] ``` This could easily be extended to `dict` literals, but I am intentionally limiting scope for now. --- .../mdtest/assignment/annotations.md | 72 +++++++++ .../resources/mdtest/del.md | 4 +- .../resources/mdtest/import/dunder_all.md | 3 +- .../mdtest/literal/collections/list.md | 28 +++- .../mdtest/literal/collections/set.md | 28 +++- .../mdtest/narrow/conditionals/nested.md | 11 +- .../resources/mdtest/subscript/lists.md | 8 +- .../resources/mdtest/type_compendium/tuple.md | 3 +- .../resources/mdtest/unpacking.md | 5 +- crates/ty_python_semantic/src/types.rs | 57 +++++--- crates/ty_python_semantic/src/types/class.rs | 16 +- .../src/types/diagnostic.rs | 2 +- .../ty_python_semantic/src/types/function.rs | 7 +- crates/ty_python_semantic/src/types/infer.rs | 30 +++- .../src/types/infer/builder.rs | 137 +++++++++++++++--- crates/ty_python_semantic/src/types/tuple.rs | 8 +- 16 files changed, 341 insertions(+), 78 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index f684fd6f90..62852bc0ba 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -79,6 +79,78 @@ b: tuple[int] = ("foo",) c: tuple[str | int, str] = ([], "foo") ``` +## Collection literal annotations are understood + +```toml +[environment] +python-version = "3.12" +``` + +```py +import typing + +a: list[int] = [1, 2, 3] +reveal_type(a) # revealed: list[int] + +b: list[int | str] = [1, 2, 3] +reveal_type(b) # revealed: list[int | str] + +c: typing.List[int] = [1, 2, 3] +reveal_type(c) # revealed: list[int] + +d: list[typing.Any] = [] +reveal_type(d) # revealed: list[Any] + +e: set[int] = {1, 2, 3} +reveal_type(e) # revealed: set[int] + +f: set[int | str] = {1, 2, 3} +reveal_type(f) # revealed: set[int | str] + +g: typing.Set[int] = {1, 2, 3} +reveal_type(g) # revealed: set[int] + +h: list[list[int]] = [[], [42]] +reveal_type(h) # revealed: list[list[int]] + +i: list[typing.Any] = [1, 2, "3", ([4],)] +reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]] + +j: list[tuple[str | int, ...]] = [(1, 2), ("foo", "bar"), ()] +reveal_type(j) # revealed: list[tuple[str | int, ...]] + +k: list[tuple[list[int], ...]] = [([],), ([1, 2], [3, 4]), ([5], [6], [7])] +reveal_type(k) # revealed: list[tuple[list[int], ...]] + +l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"]) +reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]] + +type IntList = list[int] + +m: IntList = [1, 2, 3] +reveal_type(m) # revealed: list[int] + +# TODO: this should type-check and avoid literal promotion +# error: [invalid-assignment] "Object of type `list[int]` is not assignable to `list[Literal[1, 2, 3]]`" +n: list[typing.Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(n) # revealed: list[Literal[1, 2, 3]] + +# TODO: this should type-check and avoid literal promotion +# error: [invalid-assignment] "Object of type `list[str]` is not assignable to `list[LiteralString]`" +o: list[typing.LiteralString] = ["a", "b", "c"] +reveal_type(o) # revealed: list[LiteralString] +``` + +## Incorrect collection literal assignments are complained aobut + +```py +# error: [invalid-assignment] "Object of type `list[int]` is not assignable to `list[str]`" +a: list[str] = [1, 2, 3] + +# error: [invalid-assignment] "Object of type `set[int | str]` is not assignable to `set[int]`" +b: set[int] = {1, 2, "3"} +``` + ## PEP-604 annotations are supported ```py diff --git a/crates/ty_python_semantic/resources/mdtest/del.md b/crates/ty_python_semantic/resources/mdtest/del.md index 7ba1505906..e0d8a495c2 100644 --- a/crates/ty_python_semantic/resources/mdtest/del.md +++ b/crates/ty_python_semantic/resources/mdtest/del.md @@ -46,7 +46,7 @@ def delete(): del d # error: [unresolved-reference] "Name `d` used when not defined" delete() -reveal_type(d) # revealed: list[@Todo(list literal element type)] +reveal_type(d) # revealed: list[Unknown | int] def delete_element(): # When the `del` target isn't a name, it doesn't force local resolution. @@ -62,7 +62,7 @@ def delete_global(): delete_global() # Again, the variable should have been removed, but we don't check it. -reveal_type(d) # revealed: list[@Todo(list literal element type)] +reveal_type(d) # revealed: list[Unknown | int] def delete_nonlocal(): e = 2 diff --git a/crates/ty_python_semantic/resources/mdtest/import/dunder_all.md b/crates/ty_python_semantic/resources/mdtest/import/dunder_all.md index 7fbbb5907e..1fcdefb2dd 100644 --- a/crates/ty_python_semantic/resources/mdtest/import/dunder_all.md +++ b/crates/ty_python_semantic/resources/mdtest/import/dunder_all.md @@ -783,9 +783,8 @@ class A: ... ```py from subexporter import * -# TODO: Should be `list[str]` # TODO: Should we avoid including `Unknown` for this case? -reveal_type(__all__) # revealed: Unknown | list[@Todo(list literal element type)] +reveal_type(__all__) # revealed: Unknown | list[Unknown | str] __all__.append("B") diff --git a/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md b/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md index 44f7eceec2..13c02c15b8 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md +++ b/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md @@ -3,7 +3,33 @@ ## Empty list ```py -reveal_type([]) # revealed: list[@Todo(list literal element type)] +reveal_type([]) # revealed: list[Unknown] +``` + +## List of tuples + +```py +reveal_type([(1, 2), (3, 4)]) # revealed: list[Unknown | tuple[int, int]] +``` + +## List of functions + +```py +def a(_: int) -> int: + return 0 + +def b(_: int) -> int: + return 1 + +x = [a, b] +reveal_type(x) # revealed: list[Unknown | ((_: int) -> int)] +``` + +## Mixed list + +```py +# revealed: list[Unknown | int | tuple[int, int] | tuple[int, int, int]] +reveal_type([1, (1, 2), (1, 2, 3)]) ``` ## List comprehensions diff --git a/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md b/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md index 39cd5ed5fa..6c6855e40e 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md +++ b/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md @@ -3,7 +3,33 @@ ## Basic set ```py -reveal_type({1, 2}) # revealed: set[@Todo(set literal element type)] +reveal_type({1, 2}) # revealed: set[Unknown | int] +``` + +## Set of tuples + +```py +reveal_type({(1, 2), (3, 4)}) # revealed: set[Unknown | tuple[int, int]] +``` + +## Set of functions + +```py +def a(_: int) -> int: + return 0 + +def b(_: int) -> int: + return 1 + +x = {a, b} +reveal_type(x) # revealed: set[Unknown | ((_: int) -> int)] +``` + +## Mixed set + +```py +# revealed: set[Unknown | int | tuple[int, int] | tuple[int, int, int]] +reveal_type({1, (1, 2), (1, 2, 3)}) ``` ## Set comprehensions diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md index 79b9dbb62b..52750aec89 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md @@ -310,17 +310,13 @@ no longer valid in the inner lazy scope. def f(l: list[str | None]): if l[0] is not None: def _(): - # TODO: should be `str | None` - reveal_type(l[0]) # revealed: str | None | @Todo(list literal element type) - # TODO: should be of type `list[None]` + reveal_type(l[0]) # revealed: str | None | Unknown l = [None] def f(l: list[str | None]): l[0] = "a" def _(): - # TODO: should be `str | None` - reveal_type(l[0]) # revealed: str | None | @Todo(list literal element type) - # TODO: should be of type `list[None]` + reveal_type(l[0]) # revealed: str | None | Unknown l = [None] def f(l: list[str | None]): @@ -328,8 +324,7 @@ def f(l: list[str | None]): def _(): l: list[str | None] = [None] def _(): - # TODO: should be `str | None` - reveal_type(l[0]) # revealed: @Todo(list literal element type) + reveal_type(l[0]) # revealed: str | None def _(): def _(): diff --git a/crates/ty_python_semantic/resources/mdtest/subscript/lists.md b/crates/ty_python_semantic/resources/mdtest/subscript/lists.md index 2954092d2c..d4026b1995 100644 --- a/crates/ty_python_semantic/resources/mdtest/subscript/lists.md +++ b/crates/ty_python_semantic/resources/mdtest/subscript/lists.md @@ -9,13 +9,11 @@ A list can be indexed into with: ```py x = [1, 2, 3] -reveal_type(x) # revealed: list[@Todo(list literal element type)] +reveal_type(x) # revealed: list[Unknown | int] -# TODO reveal int -reveal_type(x[0]) # revealed: @Todo(list literal element type) +reveal_type(x[0]) # revealed: Unknown | int -# TODO reveal list[int] -reveal_type(x[0:1]) # revealed: list[@Todo(list literal element type)] +reveal_type(x[0:1]) # revealed: list[Unknown | int] # error: [invalid-argument-type] reveal_type(x["a"]) # revealed: Unknown diff --git a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md index ab9baee1a4..1f1d52a57f 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md +++ b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md @@ -55,8 +55,7 @@ def f(x: Iterable[int], y: list[str], z: Never, aa: list[Never], bb: LiskovUncom reveal_type(tuple((1, 2))) # revealed: tuple[Literal[1], Literal[2]] -# TODO: should be `tuple[Literal[1], ...]` -reveal_type(tuple([1])) # revealed: tuple[@Todo(list literal element type), ...] +reveal_type(tuple([1])) # revealed: tuple[Unknown | int, ...] # error: [invalid-argument-type] reveal_type(tuple[int]([1])) # revealed: tuple[int] diff --git a/crates/ty_python_semantic/resources/mdtest/unpacking.md b/crates/ty_python_semantic/resources/mdtest/unpacking.md index 377fec25fb..20a89625ed 100644 --- a/crates/ty_python_semantic/resources/mdtest/unpacking.md +++ b/crates/ty_python_semantic/resources/mdtest/unpacking.md @@ -213,9 +213,8 @@ reveal_type(d) # revealed: Literal[2] ```py a, b = [1, 2] -# TODO: should be `int` for both `a` and `b` -reveal_type(a) # revealed: @Todo(list literal element type) -reveal_type(b) # revealed: @Todo(list literal element type) +reveal_type(a) # revealed: Unknown | int +reveal_type(b) # revealed: Unknown | int ``` ### Simple unpacking diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index eef84552b7..13e1a62278 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1130,11 +1130,30 @@ impl<'db> Type<'db> { Type::IntLiteral(_) => Some(KnownClass::Int.to_instance(db)), Type::BytesLiteral(_) => Some(KnownClass::Bytes.to_instance(db)), Type::ModuleLiteral(_) => Some(KnownClass::ModuleType.to_instance(db)), + Type::FunctionLiteral(_) => Some(KnownClass::FunctionType.to_instance(db)), Type::EnumLiteral(literal) => Some(literal.enum_class_instance(db)), _ => None, } } + /// If this type is a literal, promote it to a type that this literal is an instance of. + /// + /// Note that this function tries to promote literals to a more user-friendly form than their + /// fallback instance type. For example, `def _() -> int` is promoted to `Callable[[], int]`, + /// as opposed to `FunctionType`. + pub(crate) fn literal_promotion_type(self, db: &'db dyn Db) -> Option> { + match self { + Type::StringLiteral(_) | Type::LiteralString => Some(KnownClass::Str.to_instance(db)), + Type::BooleanLiteral(_) => Some(KnownClass::Bool.to_instance(db)), + Type::IntLiteral(_) => Some(KnownClass::Int.to_instance(db)), + Type::BytesLiteral(_) => Some(KnownClass::Bytes.to_instance(db)), + Type::ModuleLiteral(_) => Some(KnownClass::ModuleType.to_instance(db)), + Type::EnumLiteral(literal) => Some(literal.enum_class_instance(db)), + Type::FunctionLiteral(literal) => Some(Type::Callable(literal.into_callable_type(db))), + _ => None, + } + } + /// Return a "normalized" version of `self` that ensures that equivalent types have the same Salsa ID. /// /// A normalized type: @@ -1704,18 +1723,13 @@ impl<'db> Type<'db> { | Type::IntLiteral(_) | Type::BytesLiteral(_) | Type::ModuleLiteral(_) - | Type::EnumLiteral(_), + | Type::EnumLiteral(_) + | Type::FunctionLiteral(_), _, ) => (self.literal_fallback_instance(db)).when_some_and(|instance| { instance.has_relation_to_impl(db, target, relation, visitor) }), - // A `FunctionLiteral` type is a single-valued type like the other literals handled above, - // so it also, for now, just delegates to its instance fallback. - (Type::FunctionLiteral(_), _) => KnownClass::FunctionType - .to_instance(db) - .has_relation_to_impl(db, target, relation, visitor), - // The same reasoning applies for these special callable types: (Type::BoundMethod(_), _) => KnownClass::MethodType .to_instance(db) @@ -5979,8 +5993,9 @@ impl<'db> Type<'db> { self } } - TypeMapping::PromoteLiterals | TypeMapping::BindLegacyTypevars(_) | - TypeMapping::MarkTypeVarsInferable(_) => self, + TypeMapping::PromoteLiterals + | TypeMapping::BindLegacyTypevars(_) + | TypeMapping::MarkTypeVarsInferable(_) => self, TypeMapping::Materialize(materialization_kind) => { Type::TypeVar(bound_typevar.materialize_impl(db, *materialization_kind, visitor)) } @@ -6000,10 +6015,10 @@ impl<'db> Type<'db> { self } } - TypeMapping::PromoteLiterals | - TypeMapping::BindLegacyTypevars(_) | - TypeMapping::BindSelf(_) | - TypeMapping::ReplaceSelf { .. } + TypeMapping::PromoteLiterals + | TypeMapping::BindLegacyTypevars(_) + | TypeMapping::BindSelf(_) + | TypeMapping::ReplaceSelf { .. } => self, TypeMapping::Materialize(materialization_kind) => Type::NonInferableTypeVar(bound_typevar.materialize_impl(db, *materialization_kind, visitor)) @@ -6023,7 +6038,13 @@ impl<'db> Type<'db> { } Type::FunctionLiteral(function) => { - Type::FunctionLiteral(function.with_type_mapping(db, type_mapping)) + let function = Type::FunctionLiteral(function.with_type_mapping(db, type_mapping)); + + match type_mapping { + TypeMapping::PromoteLiterals => function.literal_promotion_type(db) + .expect("function literal should have a promotion type"), + _ => function + } } Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new( @@ -6129,8 +6150,8 @@ impl<'db> Type<'db> { TypeMapping::ReplaceSelf { .. } | TypeMapping::MarkTypeVarsInferable(_) | TypeMapping::Materialize(_) => self, - TypeMapping::PromoteLiterals => self.literal_fallback_instance(db) - .expect("literal type should have fallback instance type"), + TypeMapping::PromoteLiterals => self.literal_promotion_type(db) + .expect("literal type should have a promotion type"), } Type::Dynamic(_) => match type_mapping { @@ -6663,8 +6684,8 @@ pub enum TypeMapping<'a, 'db> { Specialization(Specialization<'db>), /// Applies a partial specialization to the type PartialSpecialization(PartialSpecialization<'a, 'db>), - /// Promotes any literal types to their corresponding instance types (e.g. `Literal["string"]` - /// to `str`) + /// Replaces any literal types with their corresponding promoted type form (e.g. `Literal["string"]` + /// to `str`, or `def _() -> int` to `Callable[[], int]`). PromoteLiterals, /// Binds a legacy typevar with the generic context (class, function, type alias) that it is /// being used in. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index bfacec997a..f31a86b3e3 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1048,7 +1048,7 @@ impl<'db> ClassType<'db> { /// Return a callable type (or union of callable types) that represents the callable /// constructor signature of this class. - #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(cycle_fn=into_callable_cycle_recover, cycle_initial=into_callable_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(super) fn into_callable(self, db: &'db dyn Db) -> Type<'db> { let self_ty = Type::from(self); let metaclass_dunder_call_function_symbol = self_ty @@ -1208,6 +1208,20 @@ impl<'db> ClassType<'db> { } } +#[allow(clippy::trivially_copy_pass_by_ref)] +fn into_callable_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: ClassType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn into_callable_cycle_initial<'db>(_db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { + Type::Never +} + impl<'db> From> for ClassType<'db> { fn from(generic: GenericAlias<'db>) -> ClassType<'db> { ClassType::Generic(generic) diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index f7f5587d6a..1751678f82 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -2626,7 +2626,7 @@ pub(crate) fn report_undeclared_protocol_member( let binding_type = binding_type(db, definition); let suggestion = binding_type - .literal_fallback_instance(db) + .literal_promotion_type(db) .unwrap_or(binding_type); if should_give_hint(db, suggestion) { diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 8ed70e6665..c7b94f0584 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1081,16 +1081,13 @@ fn is_instance_truthiness<'db>( | Type::StringLiteral(..) | Type::LiteralString | Type::ModuleLiteral(..) - | Type::EnumLiteral(..) => always_true_if( + | Type::EnumLiteral(..) + | Type::FunctionLiteral(..) => always_true_if( ty.literal_fallback_instance(db) .as_ref() .is_some_and(is_instance), ), - Type::FunctionLiteral(..) => { - always_true_if(is_instance(&KnownClass::FunctionType.to_instance(db))) - } - Type::ClassLiteral(..) => always_true_if(is_instance(&KnownClass::Type.to_instance(db))), Type::TypeAlias(alias) => is_instance_truthiness(db, alias.value_type(db), class), diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 3c0806373c..7f81248a22 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -49,8 +49,9 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::types::diagnostic::TypeCheckDiagnostics; +use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; -use crate::types::{ClassLiteral, Truthiness, Type, TypeAndQualifiers}; +use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers}; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -355,10 +356,31 @@ pub(crate) struct TypeContext<'db> { } impl<'db> TypeContext<'db> { - pub(crate) fn new(annotation: Type<'db>) -> Self { - Self { - annotation: Some(annotation), + pub(crate) fn new(annotation: Option>) -> Self { + Self { annotation } + } + + // If the type annotation is a specialized instance of the given `KnownClass`, returns the + // specialization. + fn known_specialization( + &self, + known_class: KnownClass, + db: &'db dyn Db, + ) -> Option> { + let class_type = match self.annotation? { + Type::NominalInstance(instance) => instance, + Type::TypeAlias(alias) => alias.value_type(db).into_nominal_instance()?, + _ => return None, } + .class(db); + + if !class_type.is_known(db, known_class) { + return None; + } + + class_type + .into_generic_alias() + .map(|generic_alias| generic_alias.specialization(db)) } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 34638b8bb3..99dfbe428e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -73,13 +73,13 @@ use crate::types::diagnostic::{ use crate::types::function::{ FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral, }; -use crate::types::generics::LegacyGenericBase; use crate::types::generics::{GenericContext, bind_typevar}; +use crate::types::generics::{LegacyGenericBase, SpecializationBuilder}; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; use crate::types::signatures::Signature; use crate::types::subclass_of::SubclassOfInner; -use crate::types::tuple::{Tuple, TupleSpec, TupleType}; +use crate::types::tuple::{Tuple, TupleLength, TupleSpec, TupleType}; use crate::types::typed_dict::{ TypedDictAssignmentKind, validate_typed_dict_constructor, validate_typed_dict_dict_literal, validate_typed_dict_key_assignment, @@ -90,8 +90,9 @@ use crate::types::{ IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, - TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, - TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, + TypeContext, TypeMapping, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, + TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, + todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -4008,7 +4009,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(value) = value { self.infer_maybe_standalone_expression( value, - TypeContext::new(annotated.inner_type()), + TypeContext::new(Some(annotated.inner_type())), ); } @@ -4101,8 +4102,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { debug_assert!(PlaceExpr::try_from_expr(target).is_some()); if let Some(value) = value { - let inferred_ty = self - .infer_maybe_standalone_expression(value, TypeContext::new(declared.inner_type())); + let inferred_ty = self.infer_maybe_standalone_expression( + value, + TypeContext::new(Some(declared.inner_type())), + ); let mut inferred_ty = if target .as_name_expr() .is_some_and(|name| &name.id == "TYPE_CHECKING") @@ -5236,7 +5239,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_tuple_expression( &mut self, tuple: &ast::ExprTuple, - _tcx: TypeContext<'db>, + tcx: TypeContext<'db>, ) -> Type<'db> { let ast::ExprTuple { range: _, @@ -5246,11 +5249,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { parenthesized: _, } = tuple; + let annotated_tuple = tcx + .known_specialization(KnownClass::Tuple, self.db()) + .and_then(|specialization| { + specialization + .tuple(self.db()) + .expect("the specialization of `KnownClass::Tuple` must have a tuple spec") + .resize(self.db(), TupleLength::Fixed(elts.len())) + .ok() + }); + + let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements); + let db = self.db(); let divergent = Type::divergent(self.scope()); let element_types = elts.iter().map(|element| { - // TODO: Use the type context for more precise inference. - let element_type = self.infer_expression(element, TypeContext::default()); + let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied(); + let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty)); + if element_type.has_divergent_type(self.db(), divergent) { divergent } else { @@ -5261,7 +5277,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::heterogeneous_tuple(db, element_types) } - fn infer_list_expression(&mut self, list: &ast::ExprList, _tcx: TypeContext<'db>) -> Type<'db> { + fn infer_list_expression(&mut self, list: &ast::ExprList, tcx: TypeContext<'db>) -> Type<'db> { let ast::ExprList { range: _, node_index: _, @@ -5269,28 +5285,102 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ctx: _, } = list; - // TODO: Use the type context for more precise inference. - for elt in elts { - self.infer_expression(elt, TypeContext::default()); - } - - KnownClass::List - .to_specialized_instance(self.db(), [todo_type!("list literal element type")]) + self.infer_collection_literal(elts, tcx, KnownClass::List) + .unwrap_or_else(|| { + KnownClass::List.to_specialized_instance(self.db(), [Type::unknown()]) + }) } - fn infer_set_expression(&mut self, set: &ast::ExprSet, _tcx: TypeContext<'db>) -> Type<'db> { + fn infer_set_expression(&mut self, set: &ast::ExprSet, tcx: TypeContext<'db>) -> Type<'db> { let ast::ExprSet { range: _, node_index: _, elts, } = set; - // TODO: Use the type context for more precise inference. - for elt in elts { - self.infer_expression(elt, TypeContext::default()); + self.infer_collection_literal(elts, tcx, KnownClass::Set) + .unwrap_or_else(|| { + KnownClass::Set.to_specialized_instance(self.db(), [Type::unknown()]) + }) + } + + // Infer the type of a collection literal expression. + fn infer_collection_literal( + &mut self, + elts: &[ast::Expr], + tcx: TypeContext<'db>, + collection_class: KnownClass, + ) -> Option> { + // Extract the type variable `T` from `list[T]` in typeshed. + fn elts_ty( + collection_class: KnownClass, + db: &dyn Db, + ) -> Option<(ClassLiteral<'_>, Type<'_>)> { + let class_literal = collection_class.try_to_class_literal(db)?; + let generic_context = class_literal.generic_context(db)?; + let variables = generic_context.variables(db); + let elts_ty = variables.iter().exactly_one().ok()?; + Some((class_literal, Type::TypeVar(*elts_ty))) } - KnownClass::Set.to_specialized_instance(self.db(), [todo_type!("set literal element type")]) + let annotated_elts_ty = tcx + .known_specialization(collection_class, self.db()) + .and_then(|specialization| specialization.types(self.db()).iter().exactly_one().ok()) + .copied(); + + let (class_literal, elts_ty) = elts_ty(collection_class, self.db()).unwrap_or_else(|| { + let name = collection_class.name(self.db()); + panic!("Typeshed should always have a `{name}` class in `builtins.pyi` with a single type variable") + }); + + let mut elements_are_assignable = true; + let mut inferred_elt_tys = Vec::with_capacity(elts.len()); + + // Infer the type of each element in the collection literal. + for elt in elts { + let inferred_elt_ty = self.infer_expression(elt, TypeContext::new(annotated_elts_ty)); + inferred_elt_tys.push(inferred_elt_ty); + + if let Some(annotated_elts_ty) = annotated_elts_ty { + elements_are_assignable &= + inferred_elt_ty.is_assignable_to(self.db(), annotated_elts_ty); + } + } + + // Create a set of constraints to infer a precise type for `T`. + let mut builder = SpecializationBuilder::new(self.db()); + + match annotated_elts_ty { + // If the inferred type of any element is not assignable to the type annotation, we + // ignore it, as to provide a more precise error message. + Some(_) if !elements_are_assignable => {} + + // Otherwise, the annotated type acts as a constraint for `T`. + // + // Note that we infer the annotated type _before_ the elements, to closer match the order + // of any unions written in the type annotation. + Some(annotated_elts_ty) => { + builder.infer(elts_ty, annotated_elts_ty).ok()?; + } + + // If a valid type annotation was not provided, avoid restricting the type of the collection + // by unioning the inferred type with `Unknown`. + None => builder.infer(elts_ty, Type::unknown()).ok()?, + } + + // The inferred type of each element acts as an additional constraint on `T`. + for inferred_elt_ty in inferred_elt_tys { + // Convert any element literals to their promoted type form to avoid excessively large + // unions for large nested list literals, which the constraint solver struggles with. + let inferred_elt_ty = + inferred_elt_ty.apply_type_mapping(self.db(), &TypeMapping::PromoteLiterals); + builder.infer(elts_ty, inferred_elt_ty).ok()?; + } + + let class_type = class_literal + .apply_specialization(self.db(), |generic_context| builder.build(generic_context)); + + Type::from(class_type).to_instance(self.db()) } fn infer_dict_expression(&mut self, dict: &ast::ExprDict, _tcx: TypeContext<'db>) -> Type<'db> { @@ -5314,6 +5404,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ], ) } + /// Infer the type of the `iter` expression of the first comprehension. fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) { let mut comprehensions_iter = comprehensions.iter(); diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 5b76defefd..2f10e79949 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -545,11 +545,15 @@ impl VariableLengthTuple { }) } - fn prefix_elements(&self) -> impl DoubleEndedIterator + ExactSizeIterator + '_ { + pub(crate) fn prefix_elements( + &self, + ) -> impl DoubleEndedIterator + ExactSizeIterator + '_ { self.prefix.iter() } - fn suffix_elements(&self) -> impl DoubleEndedIterator + ExactSizeIterator + '_ { + pub(crate) fn suffix_elements( + &self, + ) -> impl DoubleEndedIterator + ExactSizeIterator + '_ { self.suffix.iter() }