From 342b2665db8a540d6410a6c13a94cc78742a1152 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Tue, 17 Jun 2025 18:07:46 +0900 Subject: [PATCH] [ty] basic narrowing on attribute and subscript expressions (#17643) ## Summary This PR closes astral-sh/ty#164. This PR introduces a basic type narrowing mechanism for attribute/subscript expressions. Member accesses, int literal subscripts, string literal subscripts are supported (same as mypy and pyright). ## Test Plan New test cases are added to `mdtest/narrow/complex_target.md`. --------- Co-authored-by: David Peter --- .../resources/mdtest/attributes.md | 3 +- .../resources/mdtest/narrow/assignment.md | 6 + .../resources/mdtest/narrow/complex_target.md | 224 ++++++++++++ .../mdtest/narrow/conditionals/nested.md | 38 +- .../resources/mdtest/narrow/type_guards.md | 6 +- .../resources/primer/bad.txt | 1 + .../resources/primer/good.txt | 1 - crates/ty_python_semantic/src/place.rs | 5 +- .../ty_python_semantic/src/semantic_index.rs | 4 +- .../src/semantic_index/builder.rs | 23 +- .../semantic_index/narrowing_constraints.rs | 2 + .../src/semantic_index/place.rs | 279 ++++++++------- .../src/semantic_index/use_def.rs | 47 ++- crates/ty_python_semantic/src/types/infer.rs | 331 ++++++++++++------ crates/ty_python_semantic/src/types/narrow.rs | 96 ++--- 15 files changed, 739 insertions(+), 327 deletions(-) create mode 100644 crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index 71288e5109..53643da8ec 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -751,7 +751,8 @@ reveal_type(C.pure_class_variable) # revealed: Unknown # and the assignment is properly attributed to the class method. # error: [invalid-attribute-access] "Cannot assign to instance attribute `pure_class_variable` from the class object ``" C.pure_class_variable = "overwritten on class" - +# TODO: should be no error +# error: [unresolved-attribute] "Attribute `pure_class_variable` can only be accessed on instances, not on the class object `` itself." reveal_type(C.pure_class_variable) # revealed: Literal["overwritten on class"] c_instance = C() diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md b/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md index 73d676a2a3..d5a59ad275 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md @@ -87,6 +87,12 @@ class _: reveal_type(a.y) # revealed: Unknown | None reveal_type(a.z) # revealed: Unknown | None +a = A() +# error: [unresolved-attribute] +a.dynamically_added = 0 +# error: [unresolved-attribute] +reveal_type(a.dynamically_added) # revealed: Literal[0] + # error: [unresolved-reference] does.nt.exist = 0 # error: [unresolved-reference] diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md b/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md new file mode 100644 index 0000000000..c74f24a0aa --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/narrow/complex_target.md @@ -0,0 +1,224 @@ +# Narrowing for complex targets (attribute expressions, subscripts) + +We support type narrowing for attributes and subscripts. + +## Attribute narrowing + +### Basic + +```py +from ty_extensions import Unknown + +class C: + x: int | None = None + +c = C() + +reveal_type(c.x) # revealed: int | None + +if c.x is not None: + reveal_type(c.x) # revealed: int +else: + reveal_type(c.x) # revealed: None + +if c.x is not None: + c.x = None + +reveal_type(c.x) # revealed: None + +c = C() + +if c.x is None: + c.x = 1 + +reveal_type(c.x) # revealed: int + +class _: + reveal_type(c.x) # revealed: int + +c = C() + +class _: + if c.x is None: + c.x = 1 + reveal_type(c.x) # revealed: int + +# TODO: should be `int` +reveal_type(c.x) # revealed: int | None + +class D: + x = None + +def unknown() -> Unknown: + return 1 + +d = D() +reveal_type(d.x) # revealed: Unknown | None +d.x = 1 +reveal_type(d.x) # revealed: Literal[1] +d.x = unknown() +reveal_type(d.x) # revealed: Unknown +``` + +Narrowing can be "reset" by assigning to the attribute: + +```py +c = C() + +if c.x is None: + reveal_type(c.x) # revealed: None + c.x = 1 + reveal_type(c.x) # revealed: Literal[1] + c.x = None + reveal_type(c.x) # revealed: None + +reveal_type(c.x) # revealed: int | None +``` + +Narrowing can also be "reset" by assigning to the object: + +```py +c = C() + +if c.x is None: + reveal_type(c.x) # revealed: None + c = C() + reveal_type(c.x) # revealed: int | None + +reveal_type(c.x) # revealed: int | None +``` + +### Multiple predicates + +```py +class C: + value: str | None + +def foo(c: C): + if c.value and len(c.value): + reveal_type(c.value) # revealed: str & ~AlwaysFalsy + + # error: [invalid-argument-type] "Argument to function `len` is incorrect: Expected `Sized`, found `str | None`" + if len(c.value) and c.value: + reveal_type(c.value) # revealed: str & ~AlwaysFalsy + + if c.value is None or not len(c.value): + reveal_type(c.value) # revealed: str | None + else: # c.value is not None and len(c.value) + # TODO: should be # `str & ~AlwaysFalsy` + reveal_type(c.value) # revealed: str +``` + +### Generic class + +```toml +[environment] +python-version = "3.12" +``` + +```py +class C[T]: + x: T + y: T + + def __init__(self, x: T): + self.x = x + self.y = x + +def f(a: int | None): + c = C(a) + reveal_type(c.x) # revealed: int | None + reveal_type(c.y) # revealed: int | None + if c.x is not None: + reveal_type(c.x) # revealed: int + # In this case, it may seem like we can narrow it down to `int`, + # but different values ​​may be reassigned to `x` and `y` in another place. + reveal_type(c.y) # revealed: int | None + +def g[T](c: C[T]): + reveal_type(c.x) # revealed: T + reveal_type(c.y) # revealed: T + reveal_type(c) # revealed: C[T] + + if isinstance(c.x, int): + reveal_type(c.x) # revealed: T & int + reveal_type(c.y) # revealed: T + reveal_type(c) # revealed: C[T] + if isinstance(c.x, int) and isinstance(c.y, int): + reveal_type(c.x) # revealed: T & int + reveal_type(c.y) # revealed: T & int + # TODO: Probably better if inferred as `C[T & int]` (mypy and pyright don't support this) + reveal_type(c) # revealed: C[T] +``` + +### With intermediate scopes + +```py +class C: + def __init__(self): + self.x: int | None = None + self.y: int | None = None + +c = C() +reveal_type(c.x) # revealed: int | None +if c.x is not None: + reveal_type(c.x) # revealed: int + reveal_type(c.y) # revealed: int | None + +if c.x is not None: + def _(): + reveal_type(c.x) # revealed: Unknown | int | None + +def _(): + if c.x is not None: + reveal_type(c.x) # revealed: (Unknown & ~None) | int +``` + +## Subscript narrowing + +### Number subscript + +```py +def _(t1: tuple[int | None, int | None], t2: tuple[int, int] | tuple[None, None]): + if t1[0] is not None: + reveal_type(t1[0]) # revealed: int + reveal_type(t1[1]) # revealed: int | None + + n = 0 + if t1[n] is not None: + # Non-literal subscript narrowing are currently not supported, as well as mypy, pyright + reveal_type(t1[0]) # revealed: int | None + reveal_type(t1[n]) # revealed: int | None + reveal_type(t1[1]) # revealed: int | None + + if t2[0] is not None: + # TODO: should be int + reveal_type(t2[0]) # revealed: Unknown & ~None + # TODO: should be int + reveal_type(t2[1]) # revealed: Unknown +``` + +### String subscript + +```py +def _(d: dict[str, str | None]): + if d["a"] is not None: + reveal_type(d["a"]) # revealed: str + reveal_type(d["b"]) # revealed: str | None +``` + +## Combined attribute and subscript narrowing + +```py +class C: + def __init__(self): + self.x: tuple[int | None, int | None] = (None, None) + +class D: + def __init__(self): + self.c: tuple[C] | None = None + +d = D() +if d.c is not None and d.c[0].x[0] is not None: + reveal_type(d.c[0].x[0]) # revealed: int +``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md index b3b077f1bc..033ca89d3e 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md @@ -135,7 +135,7 @@ class _: class _3: reveal_type(a) # revealed: A # TODO: should be `D | None` - reveal_type(a.b.c1.d) # revealed: D + reveal_type(a.b.c1.d) # revealed: Unknown a.b.c1 = C() a.b.c1.d = D() @@ -173,12 +173,10 @@ def f(x: str | None): reveal_type(g) # revealed: str if a.x is not None: - # TODO(#17643): should be `Unknown | str` - reveal_type(a.x) # revealed: Unknown | str | None + reveal_type(a.x) # revealed: (Unknown & ~None) | str if l[0] is not None: - # TODO(#17643): should be `str` - reveal_type(l[0]) # revealed: str | None + reveal_type(l[0]) # revealed: str class C: if x is not None: @@ -191,12 +189,10 @@ def f(x: str | None): reveal_type(g) # revealed: str if a.x is not None: - # TODO(#17643): should be `Unknown | str` - reveal_type(a.x) # revealed: Unknown | str | None + reveal_type(a.x) # revealed: (Unknown & ~None) | str if l[0] is not None: - # TODO(#17643): should be `str` - reveal_type(l[0]) # revealed: str | None + reveal_type(l[0]) # revealed: str # TODO: should be str # This could be fixed if we supported narrowing with if clauses in comprehensions. @@ -241,22 +237,18 @@ def f(x: str | None): reveal_type(a.x) # revealed: Unknown | str | None class D: - # TODO(#17643): should be `Unknown | str` - reveal_type(a.x) # revealed: Unknown | str | None + reveal_type(a.x) # revealed: (Unknown & ~None) | str - # TODO(#17643): should be `Unknown | str` - [reveal_type(a.x) for _ in range(1)] # revealed: Unknown | str | None + [reveal_type(a.x) for _ in range(1)] # revealed: (Unknown & ~None) | str if l[0] is not None: def _(): reveal_type(l[0]) # revealed: str | None class D: - # TODO(#17643): should be `str` - reveal_type(l[0]) # revealed: str | None + reveal_type(l[0]) # revealed: str - # TODO(#17643): should be `str` - [reveal_type(l[0]) for _ in range(1)] # revealed: str | None + [reveal_type(l[0]) for _ in range(1)] # revealed: str ``` ### Narrowing constraints introduced in multiple scopes @@ -299,24 +291,20 @@ def f(x: str | Literal[1] | None): if a.x is not None: def _(): if a.x != 1: - # TODO(#17643): should be `Unknown | str | None` - reveal_type(a.x) # revealed: Unknown | str | Literal[1] | None + reveal_type(a.x) # revealed: (Unknown & ~Literal[1]) | str | None class D: if a.x != 1: - # TODO(#17643): should be `Unknown | str` - reveal_type(a.x) # revealed: Unknown | str | Literal[1] | None + reveal_type(a.x) # revealed: (Unknown & ~Literal[1] & ~None) | str if l[0] is not None: def _(): if l[0] != 1: - # TODO(#17643): should be `str | None` - reveal_type(l[0]) # revealed: str | Literal[1] | None + reveal_type(l[0]) # revealed: str | None class D: if l[0] != 1: - # TODO(#17643): should be `str` - reveal_type(l[0]) # revealed: str | Literal[1] | None + reveal_type(l[0]) # revealed: str ``` ### Narrowing constraints with bindings in class scope, and nested scopes diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md index c65a3b22c6..da19b56948 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type_guards.md @@ -220,8 +220,7 @@ def _(a: tuple[str, int] | tuple[int, str], c: C[Any]): if reveal_type(is_int(a[0])): # revealed: TypeIs[int @ a[0]] # TODO: Should be `tuple[int, str]` reveal_type(a) # revealed: tuple[str, int] | tuple[int, str] - # TODO: Should be `int` - reveal_type(a[0]) # revealed: Unknown + reveal_type(a[0]) # revealed: Unknown & int # TODO: Should be `TypeGuard[str @ c.v]` if reveal_type(guard_str(c.v)): # revealed: @Todo(`TypeGuard[]` special form) @@ -231,8 +230,7 @@ def _(a: tuple[str, int] | tuple[int, str], c: C[Any]): if reveal_type(is_int(c.v)): # revealed: TypeIs[int @ c.v] reveal_type(c) # revealed: C[Any] - # TODO: Should be `int` - reveal_type(c.v) # revealed: Any + reveal_type(c.v) # revealed: Any & int ``` Indirect usage is supported within the same scope: diff --git a/crates/ty_python_semantic/resources/primer/bad.txt b/crates/ty_python_semantic/resources/primer/bad.txt index b3d6aa33b1..f92289d6ec 100644 --- a/crates/ty_python_semantic/resources/primer/bad.txt +++ b/crates/ty_python_semantic/resources/primer/bad.txt @@ -17,4 +17,5 @@ setuptools # vendors packaging, see above spack # slow, success, but mypy-primer hangs processing the output spark # too many iterations steam.py # hangs (single threaded) +tornado # bad use-def map (https://github.com/astral-sh/ty/issues/365) xarray # too many iterations diff --git a/crates/ty_python_semantic/resources/primer/good.txt b/crates/ty_python_semantic/resources/primer/good.txt index 8d5dc64438..be08556497 100644 --- a/crates/ty_python_semantic/resources/primer/good.txt +++ b/crates/ty_python_semantic/resources/primer/good.txt @@ -110,7 +110,6 @@ stone strawberry streamlit sympy -tornado trio twine typeshed-stats diff --git a/crates/ty_python_semantic/src/place.rs b/crates/ty_python_semantic/src/place.rs index 9129edee8c..33d89ccd77 100644 --- a/crates/ty_python_semantic/src/place.rs +++ b/crates/ty_python_semantic/src/place.rs @@ -661,6 +661,7 @@ fn place_by_id<'db>( // See mdtest/known_constants.md#user-defined-type_checking for details. let is_considered_non_modifiable = place_table(db, scope) .place_expr(place_id) + .expr .is_name_and(|name| matches!(name, "__slots__" | "TYPE_CHECKING")); if scope.file(db).is_stub(db.upcast()) { @@ -1124,8 +1125,8 @@ mod implicit_globals { module_type_symbol_table .places() - .filter(|symbol| symbol.is_declared() && symbol.is_name()) - .map(semantic_index::place::PlaceExpr::expect_name) + .filter(|place| place.is_declared() && place.is_name()) + .map(semantic_index::place::PlaceExprWithFlags::expect_name) .filter(|symbol_name| { !matches!(&***symbol_name, "__dict__" | "__getattr__" | "__init__") }) diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index c39a5b12df..b584d61b11 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -37,8 +37,8 @@ mod reachability_constraints; mod use_def; pub(crate) use self::use_def::{ - BindingWithConstraints, BindingWithConstraintsIterator, DeclarationWithConstraint, - DeclarationsIterator, + ApplicableConstraints, BindingWithConstraints, BindingWithConstraintsIterator, + DeclarationWithConstraint, DeclarationsIterator, }; type PlaceSet = hashbrown::HashMap; diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 3e608a5c6e..1fc3e3dd1f 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -33,7 +33,7 @@ use crate::semantic_index::definition::{ use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::place::{ FileScopeId, NodeWithScopeKey, NodeWithScopeKind, NodeWithScopeRef, PlaceExpr, - PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId, + PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId, }; use crate::semantic_index::predicate::{ PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, ScopedPredicateId, @@ -295,6 +295,15 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { // If the scope that we just popped off is an eager scope, we need to "lock" our view of // which bindings reach each of the uses in the scope. Loop through each enclosing scope, // looking for any that bind each place. + // TODO: Bindings in eager nested scopes also need to be recorded. For example: + // ```python + // class C: + // x: int | None = None + // c = C() + // class _: + // c.x = 1 + // reveal_type(c.x) # revealed: Literal[1] + // ``` for enclosing_scope_info in self.scope_stack.iter().rev() { let enclosing_scope_id = enclosing_scope_info.file_scope_id; let enclosing_scope_kind = self.scopes[enclosing_scope_id].kind(); @@ -306,7 +315,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { // it may refer to the enclosing scope bindings // so we also need to snapshot the bindings of the enclosing scope. - let Some(enclosing_place_id) = enclosing_place_table.place_id_by_expr(nested_place) + let Some(enclosing_place_id) = + enclosing_place_table.place_id_by_expr(&nested_place.expr) else { continue; }; @@ -388,7 +398,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { /// Add a place to the place table and the use-def map. /// Return the [`ScopedPlaceId`] that uniquely identifies the place in both. - fn add_place(&mut self, place_expr: PlaceExpr) -> ScopedPlaceId { + fn add_place(&mut self, place_expr: PlaceExprWithFlags) -> ScopedPlaceId { let (place_id, added) = self.current_place_table().add_place(place_expr); if added { self.current_use_def_map_mut().add_place(place_id); @@ -1863,7 +1873,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { walk_stmt(self, stmt); for target in targets { if let Ok(target) = PlaceExpr::try_from(target) { - let place_id = self.add_place(target); + let place_id = self.add_place(PlaceExprWithFlags::new(target)); self.current_place_table().mark_place_used(place_id); self.delete_binding(place_id); } @@ -1898,7 +1908,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { ast::Expr::Name(ast::ExprName { ctx, .. }) | ast::Expr::Attribute(ast::ExprAttribute { ctx, .. }) | ast::Expr::Subscript(ast::ExprSubscript { ctx, .. }) => { - if let Ok(mut place_expr) = PlaceExpr::try_from(expr) { + if let Ok(place_expr) = PlaceExpr::try_from(expr) { + let mut place_expr = PlaceExprWithFlags::new(place_expr); if self.is_method_of_class().is_some() && place_expr.is_instance_attribute_candidate() { @@ -1906,7 +1917,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // i.e. typically `self` or `cls`. let accessed_object_refers_to_first_parameter = self .current_first_parameter_name - .is_some_and(|fst| place_expr.root_name() == fst); + .is_some_and(|fst| place_expr.expr.root_name() == fst); if accessed_object_refers_to_first_parameter && place_expr.is_member() { place_expr.mark_instance_attribute(); diff --git a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs index fa6280ead6..48297e1da6 100644 --- a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs @@ -30,6 +30,7 @@ use crate::list::{List, ListBuilder, ListSetReverseIterator, ListStorage}; use crate::semantic_index::ast_ids::ScopedUseId; +use crate::semantic_index::place::FileScopeId; use crate::semantic_index::predicate::ScopedPredicateId; /// A narrowing constraint associated with a live binding. @@ -42,6 +43,7 @@ pub(crate) type ScopedNarrowingConstraint = List, - flags: PlaceFlags, } impl std::fmt::Display for PlaceExpr { @@ -151,23 +148,27 @@ impl TryFrom<&ast::Expr> for PlaceExpr { } } +impl TryFrom> for PlaceExpr { + type Error = (); + + fn try_from(expr: ast::ExprRef) -> Result { + match expr { + ast::ExprRef::Name(name) => Ok(PlaceExpr::name(name.id.clone())), + ast::ExprRef::Attribute(attr) => PlaceExpr::try_from(attr), + ast::ExprRef::Subscript(subscript) => PlaceExpr::try_from(subscript), + _ => Err(()), + } + } +} + impl PlaceExpr { - pub(super) fn name(name: Name) -> Self { + pub(crate) fn name(name: Name) -> Self { Self { root_name: name, sub_segments: smallvec![], - flags: PlaceFlags::empty(), } } - fn insert_flags(&mut self, flags: PlaceFlags) { - self.flags.insert(flags); - } - - pub(super) fn mark_instance_attribute(&mut self) { - self.flags.insert(PlaceFlags::IS_INSTANCE_ATTRIBUTE); - } - pub(crate) fn root_name(&self) -> &Name { &self.root_name } @@ -191,6 +192,66 @@ impl PlaceExpr { &self.root_name } + /// Is the place just a name? + pub fn is_name(&self) -> bool { + self.sub_segments.is_empty() + } + + pub fn is_name_and(&self, f: impl FnOnce(&str) -> bool) -> bool { + self.is_name() && f(&self.root_name) + } + + /// Does the place expression have the form `.member`? + pub fn is_member(&self) -> bool { + self.sub_segments + .last() + .is_some_and(|last| last.as_member().is_some()) + } + + fn root_exprs(&self) -> RootExprs<'_> { + RootExprs { + expr_ref: self.into(), + len: self.sub_segments.len(), + } + } +} + +/// A [`PlaceExpr`] with flags, e.g. whether it is used, bound, an instance attribute, etc. +#[derive(Eq, PartialEq, Debug)] +pub struct PlaceExprWithFlags { + pub(crate) expr: PlaceExpr, + flags: PlaceFlags, +} + +impl std::fmt::Display for PlaceExprWithFlags { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.expr.fmt(f) + } +} + +impl PlaceExprWithFlags { + pub(crate) fn new(expr: PlaceExpr) -> Self { + PlaceExprWithFlags { + expr, + flags: PlaceFlags::empty(), + } + } + + fn name(name: Name) -> Self { + PlaceExprWithFlags { + expr: PlaceExpr::name(name), + flags: PlaceFlags::empty(), + } + } + + fn insert_flags(&mut self, flags: PlaceFlags) { + self.flags.insert(flags); + } + + pub(super) fn mark_instance_attribute(&mut self) { + self.flags.insert(PlaceFlags::IS_INSTANCE_ATTRIBUTE); + } + /// If the place expression has the form `.` /// (meaning it *may* be an instance attribute), /// return `Some()`. Else, return `None`. @@ -202,8 +263,8 @@ impl PlaceExpr { /// parameter of the method (i.e. `self`). To answer those questions, /// use [`Self::as_instance_attribute`]. pub(super) fn as_instance_attribute_candidate(&self) -> Option<&Name> { - if self.sub_segments.len() == 1 { - self.sub_segments[0].as_member() + if self.expr.sub_segments.len() == 1 { + self.expr.sub_segments[0].as_member() } else { None } @@ -227,6 +288,16 @@ impl PlaceExpr { self.as_instance_attribute().map(Name::as_str) == Some(name) } + /// Return `Some()` if the place expression is an instance attribute. + pub(crate) fn as_instance_attribute(&self) -> Option<&Name> { + if self.is_instance_attribute() { + debug_assert!(self.as_instance_attribute_candidate().is_some()); + self.as_instance_attribute_candidate() + } else { + None + } + } + /// Is the place an instance attribute? pub(crate) fn is_instance_attribute(&self) -> bool { let is_instance_attribute = self.flags.contains(PlaceFlags::IS_INSTANCE_ATTRIBUTE); @@ -236,14 +307,12 @@ impl PlaceExpr { is_instance_attribute } - /// Return `Some()` if the place expression is an instance attribute. - pub(crate) fn as_instance_attribute(&self) -> Option<&Name> { - if self.is_instance_attribute() { - debug_assert!(self.as_instance_attribute_candidate().is_some()); - self.as_instance_attribute_candidate() - } else { - None - } + pub(crate) fn is_name(&self) -> bool { + self.expr.is_name() + } + + pub(crate) fn is_member(&self) -> bool { + self.expr.is_member() } /// Is the place used in its containing scope? @@ -261,56 +330,58 @@ impl PlaceExpr { self.flags.contains(PlaceFlags::IS_DECLARED) } - /// Is the place just a name? - pub fn is_name(&self) -> bool { - self.sub_segments.is_empty() + pub(crate) fn as_name(&self) -> Option<&Name> { + self.expr.as_name() } - pub fn is_name_and(&self, f: impl FnOnce(&str) -> bool) -> bool { - self.is_name() && f(&self.root_name) + pub(crate) fn expect_name(&self) -> &Name { + self.expr.expect_name() } +} - /// Does the place expression have the form `.member`? - pub fn is_member(&self) -> bool { - self.sub_segments - .last() - .is_some_and(|last| last.as_member().is_some()) +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +pub struct PlaceExprRef<'a> { + pub(crate) root_name: &'a Name, + pub(crate) sub_segments: &'a [PlaceExprSubSegment], +} + +impl PartialEq for PlaceExprRef<'_> { + fn eq(&self, other: &PlaceExpr) -> bool { + self.root_name == &other.root_name && self.sub_segments == &other.sub_segments[..] } +} - pub(crate) fn segments(&self) -> PlaceSegments { - PlaceSegments { - root_name: Some(&self.root_name), - sub_segments: &self.sub_segments, - } +impl PartialEq> for PlaceExpr { + fn eq(&self, other: &PlaceExprRef<'_>) -> bool { + &self.root_name == other.root_name && &self.sub_segments[..] == other.sub_segments } +} - // TODO: Ideally this would iterate PlaceSegments instead of RootExprs, both to reduce - // allocation and to avoid having both flagged and non-flagged versions of PlaceExprs. - fn root_exprs(&self) -> RootExprs<'_> { - RootExprs { - expr: self, - len: self.sub_segments.len(), +impl<'e> From<&'e PlaceExpr> for PlaceExprRef<'e> { + fn from(expr: &'e PlaceExpr) -> Self { + PlaceExprRef { + root_name: &expr.root_name, + sub_segments: &expr.sub_segments, } } } struct RootExprs<'e> { - expr: &'e PlaceExpr, + expr_ref: PlaceExprRef<'e>, len: usize, } -impl Iterator for RootExprs<'_> { - type Item = PlaceExpr; +impl<'e> Iterator for RootExprs<'e> { + type Item = PlaceExprRef<'e>; fn next(&mut self) -> Option { if self.len == 0 { return None; } self.len -= 1; - Some(PlaceExpr { - root_name: self.expr.root_name.clone(), - sub_segments: self.expr.sub_segments[..self.len].iter().cloned().collect(), - flags: PlaceFlags::empty(), + Some(PlaceExprRef { + root_name: self.expr_ref.root_name, + sub_segments: &self.expr_ref.sub_segments[..self.len], }) } } @@ -333,41 +404,6 @@ bitflags! { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PlaceSegment<'a> { - /// A first segment of a place expression (root name), e.g. `x` in `x.y.z[0]`. - Name(&'a ast::name::Name), - Member(&'a ast::name::Name), - IntSubscript(&'a ast::Int), - StringSubscript(&'a str), -} - -#[derive(Debug, PartialEq, Eq)] -pub struct PlaceSegments<'a> { - root_name: Option<&'a ast::name::Name>, - sub_segments: &'a [PlaceExprSubSegment], -} - -impl<'a> Iterator for PlaceSegments<'a> { - type Item = PlaceSegment<'a>; - - fn next(&mut self) -> Option { - if let Some(name) = self.root_name.take() { - return Some(PlaceSegment::Name(name)); - } - if self.sub_segments.is_empty() { - return None; - } - let segment = &self.sub_segments[0]; - self.sub_segments = &self.sub_segments[1..]; - Some(match segment { - PlaceExprSubSegment::Member(name) => PlaceSegment::Member(name), - PlaceExprSubSegment::IntSubscript(int) => PlaceSegment::IntSubscript(int), - PlaceExprSubSegment::StringSubscript(string) => PlaceSegment::StringSubscript(string), - }) - } -} - /// ID that uniquely identifies a place in a file. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct FilePlaceId { @@ -575,7 +611,7 @@ impl ScopeKind { #[derive(Default, salsa::Update)] pub struct PlaceTable { /// The place expressions in this scope. - places: IndexVec, + places: IndexVec, /// The set of places. place_set: PlaceSet, @@ -586,7 +622,7 @@ impl PlaceTable { self.places.shrink_to_fit(); } - pub(crate) fn place_expr(&self, place_id: impl Into) -> &PlaceExpr { + pub(crate) fn place_expr(&self, place_id: impl Into) -> &PlaceExprWithFlags { &self.places[place_id.into()] } @@ -594,10 +630,10 @@ impl PlaceTable { pub(crate) fn root_place_exprs( &self, place_expr: &PlaceExpr, - ) -> impl Iterator { + ) -> impl Iterator { place_expr .root_exprs() - .filter_map(|place_expr| self.place_by_expr(&place_expr)) + .filter_map(|place_expr| self.place_by_expr(place_expr)) } #[expect(unused)] @@ -605,11 +641,11 @@ impl PlaceTable { self.places.indices() } - pub fn places(&self) -> impl Iterator { + pub fn places(&self) -> impl Iterator { self.places.iter() } - pub fn symbols(&self) -> impl Iterator { + pub fn symbols(&self) -> impl Iterator { self.places().filter(|place_expr| place_expr.is_name()) } @@ -620,19 +656,16 @@ impl PlaceTable { /// Returns the place named `name`. #[allow(unused)] // used in tests - pub(crate) fn place_by_name(&self, name: &str) -> Option<&PlaceExpr> { + pub(crate) fn place_by_name(&self, name: &str) -> Option<&PlaceExprWithFlags> { let id = self.place_id_by_name(name)?; Some(self.place_expr(id)) } - /// Returns the flagged place by the unflagged place expression. - /// - /// TODO: Ideally this would take a [`PlaceSegments`] instead of [`PlaceExpr`], to avoid the - /// awkward distinction between "flagged" (canonical) and unflagged [`PlaceExpr`]; in that - /// world, we would only create [`PlaceExpr`] in semantic indexing; in type inference we'd - /// create [`PlaceSegments`] if we need to look up a [`PlaceExpr`]. The [`PlaceTable`] would - /// need to gain the ability to hash and look up by a [`PlaceSegments`]. - pub(crate) fn place_by_expr(&self, place_expr: &PlaceExpr) -> Option<&PlaceExpr> { + /// Returns the flagged place. + pub(crate) fn place_by_expr<'e>( + &self, + place_expr: impl Into>, + ) -> Option<&PlaceExprWithFlags> { let id = self.place_id_by_expr(place_expr)?; Some(self.place_expr(id)) } @@ -650,12 +683,16 @@ impl PlaceTable { } /// Returns the [`ScopedPlaceId`] of the place expression. - pub(crate) fn place_id_by_expr(&self, place_expr: &PlaceExpr) -> Option { + pub(crate) fn place_id_by_expr<'e>( + &self, + place_expr: impl Into>, + ) -> Option { + let place_expr = place_expr.into(); let (id, ()) = self .place_set .raw_entry() .from_hash(Self::hash_place_expr(place_expr), |id| { - self.place_expr(*id).segments() == place_expr.segments() + self.place_expr(*id).expr == place_expr })?; Some(*id) @@ -673,10 +710,12 @@ impl PlaceTable { hasher.finish() } - fn hash_place_expr(place_expr: &PlaceExpr) -> u64 { + fn hash_place_expr<'e>(place_expr: impl Into>) -> u64 { + let place_expr = place_expr.into(); + let mut hasher = FxHasher::default(); - place_expr.root_name().as_str().hash(&mut hasher); - for segment in &place_expr.sub_segments { + place_expr.root_name.as_str().hash(&mut hasher); + for segment in place_expr.sub_segments { match segment { PlaceExprSubSegment::Member(name) => name.hash(&mut hasher), PlaceExprSubSegment::IntSubscript(int) => int.hash(&mut hasher), @@ -725,11 +764,11 @@ impl PlaceTableBuilder { match entry { RawEntryMut::Occupied(entry) => (*entry.key(), false), RawEntryMut::Vacant(entry) => { - let symbol = PlaceExpr::name(name); + let symbol = PlaceExprWithFlags::name(name); let id = self.table.places.push(symbol); entry.insert_with_hasher(hash, id, (), |id| { - PlaceTable::hash_place_expr(&self.table.places[*id]) + PlaceTable::hash_place_expr(&self.table.places[*id].expr) }); let new_id = self.associated_place_ids.push(vec![]); debug_assert_eq!(new_id, id); @@ -738,23 +777,25 @@ impl PlaceTableBuilder { } } - pub(super) fn add_place(&mut self, place_expr: PlaceExpr) -> (ScopedPlaceId, bool) { - let hash = PlaceTable::hash_place_expr(&place_expr); - let entry = self.table.place_set.raw_entry_mut().from_hash(hash, |id| { - self.table.places[*id].segments() == place_expr.segments() - }); + pub(super) fn add_place(&mut self, place_expr: PlaceExprWithFlags) -> (ScopedPlaceId, bool) { + let hash = PlaceTable::hash_place_expr(&place_expr.expr); + let entry = self + .table + .place_set + .raw_entry_mut() + .from_hash(hash, |id| self.table.places[*id].expr == place_expr.expr); match entry { RawEntryMut::Occupied(entry) => (*entry.key(), false), RawEntryMut::Vacant(entry) => { let id = self.table.places.push(place_expr); entry.insert_with_hasher(hash, id, (), |id| { - PlaceTable::hash_place_expr(&self.table.places[*id]) + PlaceTable::hash_place_expr(&self.table.places[*id].expr) }); let new_id = self.associated_place_ids.push(vec![]); debug_assert_eq!(new_id, id); - for root in self.table.places[id].root_exprs() { - if let Some(root_id) = self.table.place_id_by_expr(&root) { + for root in self.table.places[id].expr.root_exprs() { + if let Some(root_id) = self.table.place_id_by_expr(root) { self.associated_place_ids[root_id].push(id); } } @@ -775,7 +816,7 @@ impl PlaceTableBuilder { self.table.places[id].insert_flags(PlaceFlags::IS_USED); } - pub(super) fn places(&self) -> impl Iterator { + pub(super) fn places(&self) -> impl Iterator { self.table.places() } @@ -783,7 +824,7 @@ impl PlaceTableBuilder { self.table.place_id_by_expr(place_expr) } - pub(super) fn place_expr(&self, place_id: impl Into) -> &PlaceExpr { + pub(super) fn place_expr(&self, place_id: impl Into) -> &PlaceExprWithFlags { self.table.place_expr(place_id) } diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index cb57055ab8..8ac7f91811 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -237,19 +237,21 @@ use self::place_state::{ LiveDeclarationsIterator, PlaceState, ScopedDefinitionId, }; use crate::node_key::NodeKey; -use crate::semantic_index::EagerSnapshotResult; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::definition::{Definition, DefinitionState}; use crate::semantic_index::narrowing_constraints::{ ConstraintKey, NarrowingConstraints, NarrowingConstraintsBuilder, NarrowingConstraintsIterator, }; -use crate::semantic_index::place::{FileScopeId, PlaceExpr, ScopeKind, ScopedPlaceId}; +use crate::semantic_index::place::{ + FileScopeId, PlaceExpr, PlaceExprWithFlags, ScopeKind, ScopedPlaceId, +}; use crate::semantic_index::predicate::{ Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate, }; use crate::semantic_index::reachability_constraints::{ ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId, }; +use crate::semantic_index::{EagerSnapshotResult, SemanticIndex}; use crate::types::{IntersectionBuilder, Truthiness, Type, infer_narrowing_constraint}; mod place_state; @@ -320,6 +322,11 @@ pub(crate) struct UseDefMap<'db> { end_of_scope_reachability: ScopedReachabilityConstraintId, } +pub(crate) enum ApplicableConstraints<'map, 'db> { + UnboundBinding(ConstraintsIterator<'map, 'db>), + ConstrainedBindings(BindingWithConstraintsIterator<'map, 'db>), +} + impl<'db> UseDefMap<'db> { pub(crate) fn bindings_at_use( &self, @@ -328,19 +335,33 @@ impl<'db> UseDefMap<'db> { self.bindings_iterator(&self.bindings_by_use[use_id]) } - pub(crate) fn narrowing_constraints_at_use( + pub(crate) fn applicable_constraints( &self, constraint_key: ConstraintKey, - ) -> ConstraintsIterator<'_, 'db> { - let constraint = match constraint_key { - ConstraintKey::NarrowingConstraint(constraint) => constraint, - ConstraintKey::UseId(use_id) => { - self.bindings_by_use[use_id].unbound_narrowing_constraint() + enclosing_scope: FileScopeId, + expr: &PlaceExpr, + index: &'db SemanticIndex, + ) -> ApplicableConstraints<'_, 'db> { + match constraint_key { + ConstraintKey::NarrowingConstraint(constraint) => { + ApplicableConstraints::UnboundBinding(ConstraintsIterator { + predicates: &self.predicates, + constraint_ids: self.narrowing_constraints.iter_predicates(constraint), + }) + } + ConstraintKey::EagerNestedScope(nested_scope) => { + let EagerSnapshotResult::FoundBindings(bindings) = + index.eager_snapshot(enclosing_scope, expr, nested_scope) + else { + unreachable!( + "The result of `SemanticIndex::eager_snapshot` must be `FoundBindings`" + ) + }; + ApplicableConstraints::ConstrainedBindings(bindings) + } + ConstraintKey::UseId(use_id) => { + ApplicableConstraints::ConstrainedBindings(self.bindings_at_use(use_id)) } - }; - ConstraintsIterator { - predicates: &self.predicates, - constraint_ids: self.narrowing_constraints.iter_predicates(constraint), } } @@ -884,7 +905,7 @@ impl<'db> UseDefMapBuilder<'db> { &mut self, enclosing_place: ScopedPlaceId, scope: ScopeKind, - enclosing_place_expr: &PlaceExpr, + enclosing_place_expr: &PlaceExprWithFlags, ) -> ScopedEagerSnapshotId { // Names bound in class scopes are never visible to nested scopes (but attributes/subscripts are visible), // so we never need to save eager scope bindings in a class scope. diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 8e1752e7fb..39f0227d4f 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -60,7 +60,7 @@ use crate::semantic_index::ast_ids::{ }; use crate::semantic_index::definition::{ AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind, - Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind, + Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; @@ -68,7 +68,9 @@ use crate::semantic_index::narrowing_constraints::ConstraintKey; use crate::semantic_index::place::{ FileScopeId, NodeWithScopeKind, NodeWithScopeRef, PlaceExpr, ScopeId, ScopeKind, ScopedPlaceId, }; -use crate::semantic_index::{EagerSnapshotResult, SemanticIndex, place_table, semantic_index}; +use crate::semantic_index::{ + ApplicableConstraints, EagerSnapshotResult, SemanticIndex, place_table, semantic_index, +}; use crate::types::call::{ Argument, Binding, Bindings, CallArgumentTypes, CallArguments, CallError, }; @@ -746,6 +748,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .expression_type(expr.scoped_expression_id(self.db(), self.scope())) } + fn try_expression_type(&self, expr: &ast::Expr) -> Option> { + self.types + .try_expression_type(expr.scoped_expression_id(self.db(), self.scope())) + } + /// Get the type of an expression from any scope in the same file. /// /// If the expression is in the current scope, and we are inferring the entire scope, just look @@ -1510,13 +1517,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let global_use_def_map = self.index.use_def_map(FileScopeId::global()); let place_id = binding.place(self.db()); - let expr = place_table.place_expr(place_id); + let place = place_table.place_expr(place_id); let skip_non_global_scopes = self.skip_non_global_scopes(file_scope_id, place_id); let declarations = if skip_non_global_scopes { match self .index .place_table(FileScopeId::global()) - .place_id_by_expr(expr) + .place_id_by_expr(&place.expr) { Some(id) => global_use_def_map.public_declarations(id), // This case is a syntax error (load before global declaration) but ignore that here @@ -1527,18 +1534,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; let declared_ty = place_from_declarations(self.db(), declarations) - .and_then(|place| { - Ok(if matches!(place.place, Place::Type(_, Boundness::Bound)) { - place - } else if skip_non_global_scopes - || self.scope().file_scope_id(self.db()).is_global() - { - let module_type_declarations = - module_type_implicit_global_declaration(self.db(), expr)?; - place.or_fall_back_to(self.db(), || module_type_declarations) - } else { - place - }) + .and_then(|place_and_quals| { + Ok( + if matches!(place_and_quals.place, Place::Type(_, Boundness::Bound)) { + place_and_quals + } else if skip_non_global_scopes + || self.scope().file_scope_id(self.db()).is_global() + { + let module_type_declarations = + module_type_implicit_global_declaration(self.db(), &place.expr)?; + place_and_quals.or_fall_back_to(self.db(), || module_type_declarations) + } else { + place_and_quals + }, + ) }) .map( |PlaceAndQualifiers { @@ -1576,10 +1585,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) .unwrap_or_else(|(ty, conflicting)| { // TODO point out the conflicting declarations in the diagnostic? - let expr = place_table.place_expr(binding.place(db)); + let place = place_table.place_expr(binding.place(db)); if let Some(builder) = self.context.report_lint(&CONFLICTING_DECLARATIONS, node) { builder.into_diagnostic(format_args!( - "Conflicting declared types for `{expr}`: {}", + "Conflicting declared types for `{place}`: {}", conflicting.display(db) )); } @@ -1590,6 +1599,54 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // allow declarations to override inference in case of invalid assignment bound_ty = declared_ty; } + // In the following cases, the bound type may not be the same as the RHS value type. + if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node { + let value_ty = self + .try_expression_type(value) + .unwrap_or_else(|| self.infer_maybe_standalone_expression(value)); + // If the member is a data descriptor, the RHS value may differ from the value actually assigned. + if value_ty + .class_member(db, attr.id.clone()) + .place + .ignore_possibly_unbound() + .is_some_and(|ty| ty.may_be_data_descriptor(db)) + { + bound_ty = declared_ty; + } + } else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node { + let value_ty = self + .try_expression_type(value) + .unwrap_or_else(|| self.infer_expression(value)); + // Arbitrary `__getitem__`/`__setitem__` methods on a class do not + // necessarily guarantee that the passed-in value for `__setitem__` is stored and + // can be retrieved unmodified via `__getitem__`. Therefore, we currently only + // perform assignment-based narrowing on a few built-in classes (`list`, `dict`, + // `bytesarray`, `TypedDict` and `collections` types) where we are confident that + // this kind of narrowing can be performed soundly. This is the same approach as + // pyright. TODO: Other standard library classes may also be considered safe. Also, + // subclasses of these safe classes that do not override `__getitem__/__setitem__` + // may be considered safe. + let safe_mutable_classes = [ + KnownClass::List.to_instance(db), + KnownClass::Dict.to_instance(db), + KnownClass::Bytearray.to_instance(db), + KnownClass::DefaultDict.to_instance(db), + SpecialFormType::ChainMap.instance_fallback(db), + SpecialFormType::Counter.instance_fallback(db), + SpecialFormType::Deque.instance_fallback(db), + SpecialFormType::OrderedDict.instance_fallback(db), + SpecialFormType::TypedDict.instance_fallback(db), + ]; + if safe_mutable_classes.iter().all(|safe_mutable_class| { + !value_ty.is_equivalent_to(db, *safe_mutable_class) + && value_ty + .generic_origin(db) + .zip(safe_mutable_class.generic_origin(db)) + .is_none_or(|(l, r)| l != r) + }) { + bound_ty = declared_ty; + } + } self.types.bindings.insert(binding, bound_ty); } @@ -1624,9 +1681,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Fallback to bindings declared on `types.ModuleType` if it's a global symbol let scope = self.scope().file_scope_id(self.db()); let place_table = self.index.place_table(scope); - let expr = place_table.place_expr(declaration.place(self.db())); - if scope.is_global() && expr.is_name() { - module_type_implicit_global_symbol(self.db(), expr.expect_name()) + let place = place_table.place_expr(declaration.place(self.db())); + if scope.is_global() && place.is_name() { + module_type_implicit_global_symbol(self.db(), place.expect_name()) } else { Place::Unbound.into() } @@ -1677,9 +1734,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let file_scope_id = self.scope().file_scope_id(self.db()); if file_scope_id.is_global() { let place_table = self.index.place_table(file_scope_id); - let expr = place_table.place_expr(definition.place(self.db())); + let place = place_table.place_expr(definition.place(self.db())); if let Some(module_type_implicit_declaration) = - module_type_implicit_global_declaration(self.db(), expr) + module_type_implicit_global_declaration(self.db(), &place.expr) .ok() .and_then(|place| place.place.ignore_possibly_unbound()) { @@ -1691,11 +1748,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.context.report_lint(&INVALID_DECLARATION, node) { let mut diagnostic = builder.into_diagnostic(format_args!( - "Cannot shadow implicit global attribute `{expr}` with declaration of type `{}`", + "Cannot shadow implicit global attribute `{place}` with declaration of type `{}`", declared_type.display(self.db()) )); diagnostic.info(format_args!("The global symbol `{}` must always have a type assignable to `{}`", - expr, + place, module_type_implicit_declaration.display(self.db()) )); } @@ -5920,7 +5977,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // Perform narrowing with applicable constraints between the current scope and the enclosing scope. - fn narrow_with_applicable_constraints( + fn narrow_place_with_applicable_constraints( &self, expr: &PlaceExpr, mut ty: Type<'db>, @@ -5929,11 +5986,69 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let db = self.db(); for (enclosing_scope_file_id, constraint_key) in constraint_keys { let use_def = self.index.use_def_map(*enclosing_scope_file_id); - let constraints = use_def.narrowing_constraints_at_use(*constraint_key); let place_table = self.index.place_table(*enclosing_scope_file_id); let place = place_table.place_id_by_expr(expr).unwrap(); - ty = constraints.narrow(db, ty, place); + match use_def.applicable_constraints( + *constraint_key, + *enclosing_scope_file_id, + expr, + self.index, + ) { + ApplicableConstraints::UnboundBinding(constraint) => { + ty = constraint.narrow(db, ty, place); + } + // Performs narrowing based on constrained bindings. + // This handling must be performed even if narrowing is attempted and failed using `infer_place_load`. + // The result of `infer_place_load` can be applied as is only when its boundness is `Bound`. + // For example, this handling is required in the following case: + // ```python + // class C: + // x: int | None = None + // c = C() + // # c.x: int | None = + // if c.x is None: + // c.x = 1 + // # else: c.x: int = + // # `c.x` is not definitely bound here + // reveal_type(c.x) # revealed: int + // ``` + ApplicableConstraints::ConstrainedBindings(bindings) => { + let reachability_constraints = bindings.reachability_constraints; + let predicates = bindings.predicates; + let mut union = UnionBuilder::new(db); + for binding in bindings { + let static_reachability = reachability_constraints.evaluate( + db, + predicates, + binding.reachability_constraint, + ); + if static_reachability.is_always_false() { + continue; + } + match binding.binding { + DefinitionState::Defined(definition) => { + let binding_ty = binding_type(db, definition); + union = union.add( + binding.narrowing_constraint.narrow(db, binding_ty, place), + ); + } + DefinitionState::Undefined | DefinitionState::Deleted => { + union = + union.add(binding.narrowing_constraint.narrow(db, ty, place)); + } + } + } + // If there are no visible bindings, the union becomes `Never`. + // Since an unbound binding is recorded even for an undefined place, + // this can only happen if the code is unreachable + // and therefore it is correct to set the result to `Never`. + let union = union.build(); + if union.is_assignable_to(db, ty) { + ty = union; + } + } + } } ty } @@ -5956,7 +6071,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // These are looked up as attributes on `types.ModuleType`. .or_fall_back_to(db, || { module_type_implicit_global_symbol(db, symbol_name).map_type(|ty| { - self.narrow_with_applicable_constraints(&expr, ty, &constraint_keys) + self.narrow_place_with_applicable_constraints(&expr, ty, &constraint_keys) }) }) // Not found in globals? Fallback to builtins @@ -6028,7 +6143,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - /// Infer the type of a place expression, assuming a load context. + /// Infer the type of a place expression from definitions, assuming a load context. + /// This method also returns the [`ConstraintKey`]s for each scope associated with `expr`, + /// which is used to narrow by condition rather than by assignment. fn infer_place_load( &self, expr: &PlaceExpr, @@ -6041,6 +6158,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut constraint_keys = vec![]; let (local_scope_place, use_id) = self.infer_local_place_load(expr, expr_ref); + if let Some(use_id) = use_id { + constraint_keys.push((file_scope_id, ConstraintKey::UseId(use_id))); + } let place = PlaceAndQualifiers::from(local_scope_place).or_fall_back_to(db, || { let has_bindings_in_this_scope = match place_table.place_by_expr(expr) { @@ -6081,7 +6201,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for root_expr in place_table.root_place_exprs(expr) { let mut expr_ref = expr_ref; - for _ in 0..(expr.sub_segments().len() - root_expr.sub_segments().len()) { + for _ in 0..(expr.sub_segments().len() - root_expr.expr.sub_segments().len()) { match expr_ref { ast::ExprRef::Attribute(attribute) => { expr_ref = ast::ExprRef::from(&attribute.value); @@ -6092,16 +6212,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { _ => unreachable!(), } } - let (parent_place, _use_id) = self.infer_local_place_load(root_expr, expr_ref); + let (parent_place, _use_id) = + self.infer_local_place_load(&root_expr.expr, expr_ref); if let Place::Type(_, _) = parent_place { return Place::Unbound.into(); } } - if let Some(use_id) = use_id { - constraint_keys.push((file_scope_id, ConstraintKey::UseId(use_id))); - } - // Walk up parent scopes looking for a possible enclosing scope that may have a // definition of this name visible to us (would be `LOAD_DEREF` at runtime.) // Note that we skip the scope containing the use that we are resolving, since we @@ -6144,15 +6261,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { { continue; } - return place_from_bindings(db, bindings) - .map_type(|ty| { - self.narrow_with_applicable_constraints( - expr, - ty, - &constraint_keys, - ) - }) - .into(); + let place = place_from_bindings(db, bindings).map_type(|ty| { + self.narrow_place_with_applicable_constraints( + expr, + ty, + &constraint_keys, + ) + }); + constraint_keys.push(( + enclosing_scope_file_id, + ConstraintKey::EagerNestedScope(file_scope_id), + )); + return place.into(); } // There are no visible bindings / constraint here. // Don't fall back to non-eager place resolution. @@ -6163,7 +6283,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { { if enclosing_root_place.is_bound() { if let Place::Type(_, _) = - place(db, enclosing_scope_id, enclosing_root_place).place + place(db, enclosing_scope_id, &enclosing_root_place.expr) + .place { return Place::Unbound.into(); } @@ -6190,7 +6311,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // isn't bound in that scope, we should get an unbound name, not continue // falling back to other scopes / globals / builtins. return place(db, enclosing_scope_id, expr).map_type(|ty| { - self.narrow_with_applicable_constraints(expr, ty, &constraint_keys) + self.narrow_place_with_applicable_constraints(expr, ty, &constraint_keys) }); } } @@ -6215,15 +6336,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { )); } EagerSnapshotResult::FoundBindings(bindings) => { - return place_from_bindings(db, bindings) - .map_type(|ty| { - self.narrow_with_applicable_constraints( - expr, - ty, - &constraint_keys, - ) - }) - .into(); + let place = place_from_bindings(db, bindings).map_type(|ty| { + self.narrow_place_with_applicable_constraints( + expr, + ty, + &constraint_keys, + ) + }); + constraint_keys.push(( + FileScopeId::global(), + ConstraintKey::EagerNestedScope(file_scope_id), + )); + return place.into(); } // There are no visible bindings / constraint here. EagerSnapshotResult::NotFound => { @@ -6238,7 +6362,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; explicit_global_symbol(db, self.file(), name).map_type(|ty| { - self.narrow_with_applicable_constraints(expr, ty, &constraint_keys) + self.narrow_place_with_applicable_constraints(expr, ty, &constraint_keys) }) }) }); @@ -6302,6 +6426,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } + fn narrow_expr_with_applicable_constraints<'r>( + &self, + target: impl Into>, + target_ty: Type<'db>, + constraint_keys: &[(FileScopeId, ConstraintKey)], + ) -> Type<'db> { + let target = target.into(); + + if let Ok(place_expr) = PlaceExpr::try_from(target) { + self.narrow_place_with_applicable_constraints(&place_expr, target_ty, constraint_keys) + } else { + target_ty + } + } + /// Infer the type of a [`ast::ExprAttribute`] expression, assuming a load context. fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { let ast::ExprAttribute { @@ -6314,27 +6453,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let value_type = self.infer_maybe_standalone_expression(value); let db = self.db(); + let mut constraint_keys = vec![]; - // If `attribute` is a valid reference, we attempt type narrowing by assignment. + let mut assigned_type = None; if let Ok(place_expr) = PlaceExpr::try_from(attribute) { - let member = value_type.class_member(db, attr.id.clone()); - // If the member is a data descriptor, the value most recently assigned - // to the attribute may not necessarily be obtained here. - if member - .place - .ignore_possibly_unbound() - .is_none_or(|ty| !ty.may_be_data_descriptor(db)) - { - let (resolved, _) = - self.infer_place_load(&place_expr, ast::ExprRef::Attribute(attribute)); - if let Place::Type(ty, Boundness::Bound) = resolved.place { - return ty; - } + let (resolved, keys) = + self.infer_place_load(&place_expr, ast::ExprRef::Attribute(attribute)); + constraint_keys.extend(keys); + if let Place::Type(ty, Boundness::Bound) = resolved.place { + assigned_type = Some(ty); } } - value_type + let resolved_type = value_type .member(db, &attr.id) + .map_type(|ty| self.narrow_expr_with_applicable_constraints(attribute, ty, &constraint_keys)) .unwrap_with_diagnostic(|lookup_error| match lookup_error { LookupError::Unbound(_) => { let report_unresolved_attribute = self.is_reachable(attribute); @@ -6394,7 +6527,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { type_when_bound } - }).inner_type() + }) + .inner_type(); + // Even if we can obtain the attribute type based on the assignments, we still perform default type inference + // (to report errors). + assigned_type.unwrap_or(resolved_type) } fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { @@ -7839,46 +7976,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { slice, ctx: _, } = subscript; - let db = self.db(); let value_ty = self.infer_expression(value); + let mut constraint_keys = vec![]; // If `value` is a valid reference, we attempt type narrowing by assignment. if !value_ty.is_unknown() { if let Ok(expr) = PlaceExpr::try_from(subscript) { - // Type narrowing based on assignment to a subscript expression is generally - // unsound, because arbitrary `__getitem__`/`__setitem__` methods on a class do not - // necessarily guarantee that the passed-in value for `__setitem__` is stored and - // can be retrieved unmodified via `__getitem__`. Therefore, we currently only - // perform assignment-based narrowing on a few built-in classes (`list`, `dict`, - // `bytesarray`, `TypedDict` and `collections` types) where we are confident that - // this kind of narrowing can be performed soundly. This is the same approach as - // pyright. TODO: Other standard library classes may also be considered safe. Also, - // subclasses of these safe classes that do not override `__getitem__/__setitem__` - // may be considered safe. - let safe_mutable_classes = [ - KnownClass::List.to_instance(db), - KnownClass::Dict.to_instance(db), - KnownClass::Bytearray.to_instance(db), - KnownClass::DefaultDict.to_instance(db), - SpecialFormType::ChainMap.instance_fallback(db), - SpecialFormType::Counter.instance_fallback(db), - SpecialFormType::Deque.instance_fallback(db), - SpecialFormType::OrderedDict.instance_fallback(db), - SpecialFormType::TypedDict.instance_fallback(db), - ]; - if safe_mutable_classes.iter().any(|safe_mutable_class| { - value_ty.is_equivalent_to(db, *safe_mutable_class) - || value_ty - .generic_origin(db) - .zip(safe_mutable_class.generic_origin(db)) - .is_some_and(|(l, r)| l == r) - }) { - let (place, _) = - self.infer_place_load(&expr, ast::ExprRef::Subscript(subscript)); - if let Place::Type(ty, Boundness::Bound) = place.place { - self.infer_expression(slice); - return ty; - } + let (place, keys) = + self.infer_place_load(&expr, ast::ExprRef::Subscript(subscript)); + constraint_keys.extend(keys); + if let Place::Type(ty, Boundness::Bound) = place.place { + // Even if we can obtain the subscript type based on the assignments, we still perform default type inference + // (to store the expression type and to report errors). + let slice_ty = self.infer_expression(slice); + self.infer_subscript_expression_types(value, value_ty, slice_ty); + return ty; } } } @@ -7908,7 +8020,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } let slice_ty = self.infer_expression(slice); - self.infer_subscript_expression_types(value, value_ty, slice_ty) + let result_ty = self.infer_subscript_expression_types(value, value_ty, slice_ty); + self.narrow_expr_with_applicable_constraints(subscript, result_ty, &constraint_keys) } fn infer_explicit_class_specialization( diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 9fa59dfcac..2538a37335 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1,7 +1,7 @@ use crate::Db; use crate::semantic_index::ast_ids::HasScopedExpressionId; use crate::semantic_index::expression::Expression; -use crate::semantic_index::place::{PlaceTable, ScopeId, ScopedPlaceId}; +use crate::semantic_index::place::{PlaceExpr, PlaceTable, ScopeId, ScopedPlaceId}; use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, @@ -247,13 +247,12 @@ fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, } } -fn expr_name(expr: &ast::Expr) -> Option<&ast::name::Name> { +fn place_expr(expr: &ast::Expr) -> Option { match expr { - ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() { - ast::Expr::Name(ast::ExprName { id, .. }) => Some(id), - _ => None, - }, - ast::Expr::Name(ast::ExprName { id, .. }) => Some(id), + ast::Expr::Name(name) => Some(PlaceExpr::name(name.id.clone())), + ast::Expr::Attribute(attr) => PlaceExpr::try_from(attr).ok(), + ast::Expr::Subscript(subscript) => PlaceExpr::try_from(subscript).ok(), + ast::Expr::Named(named) => PlaceExpr::try_from(named.target.as_ref()).ok(), _ => None, } } @@ -314,7 +313,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { is_positive: bool, ) -> Option> { match expression_node { - ast::Expr::Name(name) => Some(self.evaluate_expr_name(name, is_positive)), + ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => { + self.evaluate_simple_expr(expression_node, is_positive) + } ast::Expr::Compare(expr_compare) => { self.evaluate_expr_compare(expr_compare, expression, is_positive) } @@ -374,27 +375,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } #[track_caller] - fn expect_expr_name_symbol(&self, symbol: &str) -> ScopedPlaceId { + fn expect_place(&self, place_expr: &PlaceExpr) -> ScopedPlaceId { self.places() - .place_id_by_name(symbol) - .expect("We should always have a symbol for every `Name` node") + .place_id_by_expr(place_expr) + .expect("We should always have a place for every `PlaceExpr`") } - fn evaluate_expr_name( + fn evaluate_simple_expr( &mut self, - expr_name: &ast::ExprName, + expr: &ast::Expr, is_positive: bool, - ) -> NarrowingConstraints<'db> { - let ast::ExprName { id, .. } = expr_name; + ) -> Option> { + let target = place_expr(expr)?; + let place = self.expect_place(&target); - let symbol = self.expect_expr_name_symbol(id); let ty = if is_positive { Type::AlwaysFalsy.negate(self.db) } else { Type::AlwaysTruthy.negate(self.db) }; - NarrowingConstraints::from_iter([(symbol, ty)]) + Some(NarrowingConstraints::from_iter([(place, ty)])) } fn evaluate_expr_named( @@ -402,11 +403,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expr_named: &ast::ExprNamed, is_positive: bool, ) -> Option> { - if let ast::Expr::Name(expr_name) = expr_named.target.as_ref() { - Some(self.evaluate_expr_name(expr_name, is_positive)) - } else { - None - } + self.evaluate_simple_expr(&expr_named.target, is_positive) } fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { @@ -598,7 +595,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool { matches!( expr, - ast::Expr::Name(_) | ast::Expr::Call(_) | ast::Expr::Named(_) + ast::Expr::Name(_) + | ast::Expr::Attribute(_) + | ast::Expr::Subscript(_) + | ast::Expr::Call(_) + | ast::Expr::Named(_) ) } @@ -644,13 +645,16 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { last_rhs_ty = Some(rhs_ty); match left { - ast::Expr::Name(_) | ast::Expr::Named(_) => { - if let Some(id) = expr_name(left) { - let symbol = self.expect_expr_name_symbol(id); + ast::Expr::Name(_) + | ast::Expr::Attribute(_) + | ast::Expr::Subscript(_) + | ast::Expr::Named(_) => { + if let Some(left) = place_expr(left) { let op = if is_positive { *op } else { op.negate() }; if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { - constraints.insert(symbol, ty); + let place = self.expect_place(&left); + constraints.insert(place, ty); } } } @@ -674,9 +678,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } }; - let id = match &**args { - [first] => match expr_name(first) { - Some(id) => id, + let target = match &**args { + [first] => match place_expr(first) { + Some(target) => target, None => continue, }, _ => continue, @@ -699,9 +703,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .into_class_literal() .is_some_and(|c| c.is_known(self.db, KnownClass::Type)) { - let symbol = self.expect_expr_name_symbol(id); + let place = self.expect_place(&target); constraints.insert( - symbol, + place, Type::instance(self.db, rhs_class.unknown_specialization(self.db)), ); } @@ -754,9 +758,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let [first_arg, second_arg] = &*expr_call.arguments.args else { return None; }; - let first_arg = expr_name(first_arg)?; + let first_arg = place_expr(first_arg)?; let function = function_type.known(self.db)?; - let symbol = self.expect_expr_name_symbol(first_arg); + let place = self.expect_place(&first_arg); if function == KnownFunction::HasAttr { let attr = inference @@ -774,7 +778,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ); return Some(NarrowingConstraints::from_iter([( - symbol, + place, constraint.negate_if(self.db, !is_positive), )])); } @@ -788,7 +792,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .generate_constraint(self.db, class_info_ty) .map(|constraint| { NarrowingConstraints::from_iter([( - symbol, + place, constraint.negate_if(self.db, !is_positive), )]) }) @@ -814,15 +818,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, singleton: ast::Singleton, ) -> Option> { - let symbol = self - .expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id); + let subject = place_expr(subject.node_ref(self.db, self.module))?; + let place = self.expect_place(&subject); let ty = match singleton { ast::Singleton::None => Type::none(self.db), ast::Singleton::True => Type::BooleanLiteral(true), ast::Singleton::False => Type::BooleanLiteral(false), }; - Some(NarrowingConstraints::from_iter([(symbol, ty)])) + Some(NarrowingConstraints::from_iter([(place, ty)])) } fn evaluate_match_pattern_class( @@ -830,11 +834,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, cls: Expression<'db>, ) -> Option> { - let symbol = self - .expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id); + let subject = place_expr(subject.node_ref(self.db, self.module))?; + let place = self.expect_place(&subject); + let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?; - Some(NarrowingConstraints::from_iter([(symbol, ty)])) + Some(NarrowingConstraints::from_iter([(place, ty)])) } fn evaluate_match_pattern_value( @@ -842,10 +847,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, value: Expression<'db>, ) -> Option> { - let symbol = self - .expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id); + let subject = place_expr(subject.node_ref(self.db, self.module))?; + let place = self.expect_place(&subject); + let ty = infer_same_file_expression_type(self.db, value, self.module); - Some(NarrowingConstraints::from_iter([(symbol, ty)])) + Some(NarrowingConstraints::from_iter([(place, ty)])) } fn evaluate_match_pattern_or(