diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 24b04f68f3..fe0bbf84fd 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -190,8 +190,7 @@ k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7] reveal_type(k) # revealed: list[tuple[list[int], ...]] l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"]) -# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]` -reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]] +reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]] type IntList = list[int] @@ -416,13 +415,14 @@ a = f("a") reveal_type(a) # revealed: list[Literal["a"]] b: list[int | Literal["a"]] = f("a") -reveal_type(b) # revealed: list[int | Literal["a"]] +reveal_type(b) # revealed: list[Literal["a"] | int] c: list[int | str] = f("a") -reveal_type(c) # revealed: list[int | str] +reveal_type(c) # revealed: list[str | int] d: list[int | tuple[int, int]] = f((1, 2)) -reveal_type(d) # revealed: list[int | tuple[int, int]] +# TODO: We could avoid reordering the union elements here. +reveal_type(d) # revealed: list[tuple[int, int] | int] e: list[int] = f(True) reveal_type(e) # revealed: list[int] @@ -437,8 +437,49 @@ def f2[T: int](x: T) -> T: return x i: int = f2(True) -reveal_type(i) # revealed: int +reveal_type(i) # revealed: Literal[True] j: int | str = f2(True) reveal_type(j) # revealed: Literal[True] ``` + +Types are not widened unnecessarily: + +```py +def id[T](x: T) -> T: + return x + +def lst[T](x: T) -> list[T]: + return [x] + +def _(i: int): + a: int | None = i + b: int | None = id(i) + c: int | str | None = id(i) + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int + reveal_type(c) # revealed: int + + a: list[int | None] | None = [i] + b: list[int | None] | None = id([i]) + c: list[int | None] | int | None = id([i]) + reveal_type(a) # revealed: list[int | None] + # TODO: these should reveal `list[int | None]` + # we currently do not use the call expression annotation as type context for argument inference + reveal_type(b) # revealed: list[Unknown | int] + reveal_type(c) # revealed: list[Unknown | int] + + a: list[int | None] | None = [i] + b: list[int | None] | None = lst(i) + c: list[int | None] | int | None = lst(i) + reveal_type(a) # revealed: list[int | None] + reveal_type(b) # revealed: list[int | None] + reveal_type(c) # revealed: list[int | None] + + a: list | None = [] + b: list | None = id([]) + c: list | int | None = id([]) + reveal_type(a) # revealed: list[Unknown] + reveal_type(b) # revealed: list[Unknown] + reveal_type(c) # revealed: list[Unknown] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md index 7b6a4369cc..f091a1c991 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md @@ -11,7 +11,7 @@ class Member: role: str = field(default="user") tag: str | None = field(default=None, init=False) -# revealed: (self: Member, name: str, role: str = str) -> None +# revealed: (self: Member, name: str, role: str = Literal["user"]) -> None reveal_type(Member.__init__) alice = Member(name="Alice", role="admin") @@ -37,7 +37,7 @@ class Data: content: list[int] = field(default_factory=list) timestamp: datetime = field(default_factory=datetime.now, init=False) -# revealed: (self: Data, content: list[int] = list[int]) -> None +# revealed: (self: Data, content: list[int] = Unknown) -> None reveal_type(Data.__init__) data = Data([1, 2, 3]) @@ -64,7 +64,7 @@ class Person: role: str = field(default="user", kw_only=True) # TODO: this would ideally show a default value of `None` for `age` -# revealed: (self: Person, name: str, *, age: int | None = int | None, role: str = str) -> None +# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None reveal_type(Person.__init__) alice = Person(role="admin", name="Alice") diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 8fc8b2fcb6..b9e3015c9f 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -907,7 +907,7 @@ grandchild: Node = {"name": "grandchild", "parent": child} nested: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": "n3", "parent": None}}} -# TODO: this should be an error (invalid type for `name` in innermost node) +# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Node`: value of type `Literal[3]`" nested_invalid: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": 3, "parent": None}}} ``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index f45e4127c7..1e1c9f0965 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1233,22 +1233,35 @@ impl<'db> Type<'db> { if yes { self.negate(db) } else { *self } } - /// Remove the union elements that are not related to `target`. + /// If the type is a union, filters union elements based on the provided predicate. + /// + /// Otherwise, returns the type unchanged. + pub(crate) fn filter_union( + self, + db: &'db dyn Db, + f: impl FnMut(&Type<'db>) -> bool, + ) -> Type<'db> { + if let Type::Union(union) = self { + union.filter(db, f) + } else { + self + } + } + + /// If the type is a union, removes union elements that are disjoint from `target`. + /// + /// Otherwise, returns the type unchanged. pub(crate) fn filter_disjoint_elements( self, db: &'db dyn Db, target: Type<'db>, inferable: InferableTypeVars<'_, 'db>, ) -> Type<'db> { - if let Type::Union(union) = self { - union.filter(db, |elem| { - !elem - .when_disjoint_from(db, target, inferable) - .is_always_satisfied() - }) - } else { - self - } + self.filter_union(db, |elem| { + !elem + .when_disjoint_from(db, target, inferable) + .is_always_satisfied() + }) } /// Returns the fallback instance type that a literal is an instance of, or `None` if the type @@ -11185,9 +11198,9 @@ impl<'db> UnionType<'db> { pub(crate) fn filter( self, db: &'db dyn Db, - filter_fn: impl FnMut(&&Type<'db>) -> bool, + mut f: impl FnMut(&Type<'db>) -> bool, ) -> Type<'db> { - Self::from_elements(db, self.elements(db).iter().filter(filter_fn)) + Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty))) } pub(crate) fn map_with_boundness( diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 641fce2ad9..2260350b46 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2524,6 +2524,7 @@ struct ArgumentTypeChecker<'a, 'db> { argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], call_expression_tcx: &'a TypeContext<'db>, + return_ty: Type<'db>, errors: &'a mut Vec>, inferable_typevars: InferableTypeVars<'db, 'db>, @@ -2531,6 +2532,7 @@ struct ArgumentTypeChecker<'a, 'db> { } impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { + #[expect(clippy::too_many_arguments)] fn new( db: &'db dyn Db, signature: &'a Signature<'db>, @@ -2538,6 +2540,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], call_expression_tcx: &'a TypeContext<'db>, + return_ty: Type<'db>, errors: &'a mut Vec>, ) -> Self { Self { @@ -2547,6 +2550,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_matches, parameter_tys, call_expression_tcx, + return_ty, errors, inferable_typevars: InferableTypeVars::None, specialization: None, @@ -2588,25 +2592,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // TODO: Use the list of inferable typevars from the generic context of the callable. let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars); - // Note that we infer the annotated type _before_ the arguments if this call is part of - // an annotated assignment, to closer match the order of any unions written in the type - // annotation. - if let Some(return_ty) = self.signature.return_ty - && let Some(call_expression_tcx) = self.call_expression_tcx.annotation - { - match call_expression_tcx { - // A type variable is not a useful type-context for expression inference, and applying it - // to the return type can lead to confusing unions in nested generic calls. - Type::TypeVar(_) => {} - - _ => { - // Ignore any specialization errors here, because the type context is only used as a hint - // to infer a more assignable return type. - let _ = builder.infer(return_ty, call_expression_tcx); - } - } - } - let parameters = self.signature.parameters(); for (argument_index, adjusted_argument_index, _, argument_type) in self.enumerate_argument_types() @@ -2631,7 +2616,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } } - self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx)); + // Build the specialization first without inferring the type context. + let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx); + let isolated_return_ty = self + .return_ty + .apply_specialization(self.db, isolated_specialization); + + let mut try_infer_tcx = || { + let return_ty = self.signature.return_ty?; + let call_expression_tcx = self.call_expression_tcx.annotation?; + + // A type variable is not a useful type-context for expression inference, and applying it + // to the return type can lead to confusing unions in nested generic calls. + if call_expression_tcx.is_type_var() { + return None; + } + + // If the return type is already assignable to the annotated type, we can ignore the + // type context and prefer the narrower inferred type. + if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) { + return None; + } + + // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an + // annotated assignment, to closer match the order of any unions written in the type annotation. + builder.infer(return_ty, call_expression_tcx).ok()?; + + // Otherwise, build the specialization again after inferring the type context. + let specialization = builder.build(generic_context, *self.call_expression_tcx); + let return_ty = return_ty.apply_specialization(self.db, specialization); + + Some((Some(specialization), return_ty)) + }; + + (self.specialization, self.return_ty) = + try_infer_tcx().unwrap_or((Some(isolated_specialization), isolated_return_ty)); } fn check_argument_type( @@ -2826,8 +2845,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } } - fn finish(self) -> (InferableTypeVars<'db, 'db>, Option>) { - (self.inferable_typevars, self.specialization) + fn finish( + self, + ) -> ( + InferableTypeVars<'db, 'db>, + Option>, + Type<'db>, + ) { + (self.inferable_typevars, self.specialization, self.return_ty) } } @@ -2985,18 +3010,16 @@ impl<'db> Binding<'db> { &self.argument_matches, &mut self.parameter_tys, call_expression_tcx, + self.return_ty, &mut self.errors, ); // If this overload is generic, first see if we can infer a specialization of the function // from the arguments that were passed in. checker.infer_specialization(); - checker.check_argument_types(); - (self.inferable_typevars, self.specialization) = checker.finish(); - if let Some(specialization) = self.specialization { - self.return_ty = self.return_ty.apply_specialization(db, specialization); - } + + (self.inferable_typevars, self.specialization, self.return_ty) = checker.finish(); } pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) { diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 20e1ab127c..456e5b237f 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1229,6 +1229,7 @@ impl<'db> SpecializationBuilder<'db> { let tcx = tcx_specialization.and_then(|specialization| { specialization.get(self.db, variable.bound_typevar) }); + ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx))); } @@ -1251,7 +1252,7 @@ impl<'db> SpecializationBuilder<'db> { pub(crate) fn infer( &mut self, formal: Type<'db>, - mut actual: Type<'db>, + actual: Type<'db>, ) -> Result<(), SpecializationError<'db>> { if formal == actual { return Ok(()); @@ -1282,9 +1283,11 @@ impl<'db> SpecializationBuilder<'db> { return Ok(()); } - // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`. - // So, here we remove the union elements that are not related to `formal`. - actual = actual.filter_disjoint_elements(self.db, formal, self.inferable); + // Remove the union elements that are not related to `formal`. + // + // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` + // to `int`. + let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable); match (formal, actual) { // TODO: We haven't implemented a full unification solver yet. If typevars appear in diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 52a8119465..f2c256f304 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> { .and_then(|ty| ty.known_specialization(db, known_class)) } - pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self { + pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self { Self { annotation: self.annotation.map(f), } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 1053fecef2..8ea2ca9988 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -5890,6 +5890,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { parenthesized: _, } = tuple; + // TODO: Use the list of inferable typevars from the generic context of tuple. + let inferable = InferableTypeVars::None; + + // Remove any union elements of that are unrelated to the tuple type. + let tcx = tcx.map(|annotation| { + annotation.filter_disjoint_elements( + self.db(), + KnownClass::Tuple.to_instance(self.db()), + inferable, + ) + }); + let annotated_tuple = tcx .known_specialization(self.db(), KnownClass::Tuple) .and_then(|specialization| { @@ -5955,7 +5967,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = dict; // Validate `TypedDict` dictionary literal assignments. - if let Some(typed_dict) = tcx.annotation.and_then(Type::as_typed_dict) + if let Some(tcx) = tcx.annotation + && let Some(typed_dict) = tcx + .filter_union(self.db(), Type::is_typed_dict) + .as_typed_dict() && let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict) { return ty; @@ -6038,9 +6053,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // TODO: Use the list of inferable typevars from the generic context of the collection // class. let inferable = InferableTypeVars::None; - let tcx = tcx.map_annotation(|annotation| { - // Remove any union elements of `annotation` that are not related to `collection_ty`. - // e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list` + + // Remove any union elements of that are unrelated to the collection type. + // + // For example, we only want the `list[int]` from `annotation: list[int] | None` if + // `collection_ty` is `list`. + let tcx = tcx.map(|annotation| { let collection_ty = collection_class.to_instance(self.db()); annotation.filter_disjoint_elements(self.db(), collection_ty, inferable) });