[ty] Improve `@override`, `@final` and Liskov checks in cases where there are multiple reachable definitions (#21767)

This commit is contained in:
Alex Waygood 2025-12-03 12:51:36 +00:00 committed by GitHub
parent 5756b3809c
commit cd079bd92e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 688 additions and 131 deletions

View File

@ -327,9 +327,7 @@ impl<'ast> MembersInScope<'ast> {
.members_in_scope_at(node)
.into_iter()
.map(|(name, memberdef)| {
let Some(def) = memberdef.definition else {
return (name, MemberInScope::other(memberdef.ty));
};
let def = memberdef.first_reachable_definition;
let kind = match *def.kind(db) {
DefinitionKind::Import(ref kind) => {
MemberImportKind::Imported(AstImportKind::Import(kind.import(parsed)))
@ -1891,13 +1889,13 @@ else:
"#);
assert_snapshot!(
test.import_from("foo", "MAGIC"), @r#"
import foo
from foo import MAGIC
if os.getenv("WHATEVER"):
from foo import MAGIC
else:
from bar import MAGIC
(foo.MAGIC)
(MAGIC)
"#);
}
@ -2108,13 +2106,13 @@ except ImportError:
");
assert_snapshot!(
test.import_from("foo", "MAGIC"), @r"
import foo
from foo import MAGIC
try:
from foo import MAGIC
except ImportError:
from bar import MAGIC
(foo.MAGIC)
(MAGIC)
");
}

View File

@ -488,11 +488,110 @@ class C(A):
pass
if coinflip():
def method2(self) -> None: ... # TODO: should emit [override-of-final-method]
def method2(self) -> None: ... # error: [override-of-final-method]
else:
def method2(self) -> None: ... # TODO: should emit [override-of-final-method]
def method2(self) -> None: ...
if coinflip():
def method3(self) -> None: ... # error: [override-of-final-method]
def method4(self) -> None: ... # error: [override-of-final-method]
# TODO: we should emit Liskov violations here too:
if coinflip():
method4 = 42 # error: [override-of-final-method]
else:
method4 = 56
```
## Definitions in statically known branches
```toml
[environment]
python-version = "3.10"
```
```py
import sys
from typing_extensions import final
class Parent:
if sys.version_info >= (3, 10):
@final
def foo(self) -> None: ...
@final
def foooo(self) -> None: ...
@final
def baaaaar(self) -> None: ...
else:
@final
def bar(self) -> None: ...
@final
def baz(self) -> None: ...
@final
def spam(self) -> None: ...
class Child(Parent):
def foo(self) -> None: ... # error: [override-of-final-method]
# The declaration on `Parent` is not reachable,
# so this is fine
def bar(self) -> None: ...
if sys.version_info >= (3, 10):
def foooo(self) -> None: ... # error: [override-of-final-method]
def baz(self) -> None: ...
else:
# Fine because this doesn't override any reachable definitions
def foooo(self) -> None: ...
# There are `@final` definitions being overridden here,
# but the definitions that override them are unreachable
def spam(self) -> None: ...
def baaaaar(self) -> None: ...
```
## Overloads in statically-known branches in stub files
<!-- snapshot-diagnostics -->
```toml
[environment]
python-version = "3.10"
```
```pyi
import sys
from typing_extensions import overload, final
class Foo:
if sys.version_info >= (3, 10):
@overload
@final
def method(self, x: int) -> int: ...
else:
@overload
def method(self, x: int) -> int: ...
@overload
def method(self, x: str) -> str: ...
if sys.version_info >= (3, 10):
@overload
def method2(self, x: int) -> int: ...
else:
@overload
@final
def method2(self, x: int) -> int: ...
@overload
def method2(self, x: str) -> str: ...
class Bar(Foo):
@overload
def method(self, x: int) -> int: ...
@overload
def method(self, x: str) -> str: ... # error: [override-of-final-method]
# This is fine: the only overload that is marked `@final`
# is in a statically-unreachable branch
@overload
def method2(self, x: int) -> int: ...
@overload
def method2(self, x: str) -> str: ...
```

View File

@ -583,3 +583,17 @@ class GoodChild2(Parent):
@staticmethod
def static_method(x: object) -> bool: ...
```
## Definitely bound members with no reachable definitions(!)
We don't emit a Liskov-violation diagnostic here, but if you're writing code like this, you probably
have bigger problems:
```py
from __future__ import annotations
class MaybeEqWhile:
while ...:
def __eq__(self, other: MaybeEqWhile) -> bool:
return True
```

View File

@ -610,3 +610,24 @@ class Child(Base):
# This is fine - Child is not directly a NamedTuple
_asdict = 42
```
## Edge case: multiple reachable definitions with distinct issues
<!-- snapshot-diagnostics -->
```py
from typing import NamedTuple
def coinflip() -> bool:
return True
class Foo(NamedTuple):
if coinflip():
_asdict: bool # error: [invalid-named-tuple] "NamedTuple field `_asdict` cannot start with an underscore"
else:
# TODO: there should only be one diagnostic here...
#
# error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
# error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
_asdict = True
```

View File

@ -220,6 +220,178 @@ class Foo:
def bar(self): ... # error: [invalid-explicit-override]
```
## Possibly-unbound definitions
```py
from typing_extensions import override
def coinflip() -> bool:
return False
class Parent:
if coinflip():
def method1(self) -> None: ...
def method2(self) -> None: ...
if coinflip():
def method3(self) -> None: ...
def method4(self) -> None: ...
else:
def method3(self) -> None: ...
def method4(self) -> None: ...
def method5(self) -> None: ...
def method6(self) -> None: ...
class Child(Parent):
@override
def method1(self) -> None: ...
@override
def method2(self) -> None: ...
if coinflip():
@override
def method3(self) -> None: ...
if coinflip():
@override
def method4(self) -> None: ...
else:
@override
def method4(self) -> None: ...
if coinflip():
@override
def method5(self) -> None: ...
if coinflip():
@override
def method6(self) -> None: ...
else:
@override
def method6(self) -> None: ...
if coinflip():
@override
def method7(self) -> None: ... # error: [invalid-explicit-override]
if coinflip():
@override
def method8(self) -> None: ... # error: [invalid-explicit-override]
else:
@override
def method8(self) -> None: ...
```
## Multiple reachable definitions, only one of which is decorated with `@override`
The diagnostic should point to the first definition decorated with `@override`, which may not
necessarily be the first definition of the symbol overall:
`runtime.py`:
```py
from typing_extensions import override, overload
def coinflip() -> bool:
return True
class Foo:
if coinflip():
def method(self, x): ...
elif coinflip():
@overload
def method(self, x: str) -> str: ...
@overload
def method(self, x: int) -> int: ...
@override
def method(self, x: str | int) -> str | int: # error: [invalid-explicit-override]
return x
elif coinflip():
@override
def method(self, x): ...
```
stub.pyi\`:
```pyi
from typing_extensions import override, overload
def coinflip() -> bool:
return True
class Foo:
if coinflip():
def method(self, x): ...
elif coinflip():
@overload
@override
def method(self, x: str) -> str: ... # error: [invalid-explicit-override]
@overload
def method(self, x: int) -> int: ...
if coinflip():
def method2(self, x): ...
elif coinflip():
@overload
@override
def method2(self, x: str) -> str: ...
@overload
def method2(self, x: int) -> int: ...
else:
# TODO: not sure why this is being emitted on this line rather than on
# the first overload in the `elif` block? Ideally it would be emitted
# on the first reachable definition, but perhaps this is due to the way
# name lookups are deferred in stub files...? -- AW
@override
def method2(self, x): ... # error: [invalid-explicit-override]
```
## Definitions in statically known branches
```toml
[environment]
python-version = "3.10"
```
```py
import sys
from typing_extensions import override, overload
class Parent:
if sys.version_info >= (3, 10):
def foo(self) -> None: ...
def foooo(self) -> None: ...
else:
def bar(self) -> None: ...
def baz(self) -> None: ...
def spam(self) -> None: ...
class Child(Parent):
@override
def foo(self) -> None: ...
# The declaration on `Parent` is not reachable,
# so this is an error
@override
def bar(self) -> None: ... # error: [invalid-explicit-override]
if sys.version_info >= (3, 10):
@override
def foooo(self) -> None: ...
@override
def baz(self) -> None: ... # error: [invalid-explicit-override]
else:
# This doesn't override any reachable definitions,
# but the subclass definition also isn't a reachable definition
# from the end of the scope with the given configuration,
# so it's not flagged
@override
def foooo(self) -> None: ...
@override
def spam(self) -> None: ...
```
## Overloads
The typing spec states that for an overloaded method, `@override` should only be applied to the
@ -293,6 +465,39 @@ class Spam:
def baz(self, x: int) -> int: ...
```
## Overloads in statically-known branches in stub files
```toml
[environment]
python-version = "3.10"
```
```pyi
import sys
from typing_extensions import overload, override
class Foo:
if sys.version_info >= (3, 10):
@overload
@override
def method(self, x: int) -> int: ... # error: [invalid-explicit-override]
else:
@overload
def method(self, x: int) -> int: ...
@overload
def method(self, x: str) -> str: ...
if sys.version_info >= (3, 10):
@overload
def method2(self, x: int) -> int: ...
else:
@overload
@override
def method2(self, x: int) -> int: ...
@overload
def method2(self, x: str) -> str: ...
```
## Classes inheriting from `Any`
```py

View File

@ -65,13 +65,18 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/final.md
51 | pass
52 |
53 | if coinflip():
54 | def method2(self) -> None: ... # TODO: should emit [override-of-final-method]
54 | def method2(self) -> None: ... # error: [override-of-final-method]
55 | else:
56 | def method2(self) -> None: ... # TODO: should emit [override-of-final-method]
56 | def method2(self) -> None: ...
57 |
58 | if coinflip():
59 | def method3(self) -> None: ... # error: [override-of-final-method]
60 | def method4(self) -> None: ... # error: [override-of-final-method]
60 |
61 | # TODO: we should emit Liskov violations here too:
62 | if coinflip():
63 | method4 = 42 # error: [override-of-final-method]
64 | else:
65 | method4 = 56
```
# Diagnostics
@ -240,6 +245,33 @@ info: rule `override-of-final-method` is enabled by default
```
```
error[override-of-final-method]: Cannot override `A.method2`
--> src/mdtest_snippet.py:54:13
|
53 | if coinflip():
54 | def method2(self) -> None: ... # error: [override-of-final-method]
| ^^^^^^^ Overrides a definition from superclass `A`
55 | else:
56 | def method2(self) -> None: ...
|
info: `A.method2` is decorated with `@final`, forbidding overrides
--> src/mdtest_snippet.py:16:9
|
14 | def method2(self) -> None: ...
15 | else:
16 | @final
| ------
17 | def method2(self) -> None: ...
| ------- `A.method2` defined here
18 |
19 | if coinflip():
|
help: Remove the override of `method2`
info: rule `override-of-final-method` is enabled by default
```
```
error[override-of-final-method]: Cannot override `A.method3`
--> src/mdtest_snippet.py:59:13
@ -247,7 +279,8 @@ error[override-of-final-method]: Cannot override `A.method3`
58 | if coinflip():
59 | def method3(self) -> None: ... # error: [override-of-final-method]
| ^^^^^^^ Overrides a definition from superclass `A`
60 | def method4(self) -> None: ... # error: [override-of-final-method]
60 |
61 | # TODO: we should emit Liskov violations here too:
|
info: `A.method3` is decorated with `@final`, forbidding overrides
--> src/mdtest_snippet.py:20:9
@ -267,12 +300,14 @@ info: rule `override-of-final-method` is enabled by default
```
error[override-of-final-method]: Cannot override `A.method4`
--> src/mdtest_snippet.py:60:13
--> src/mdtest_snippet.py:63:9
|
58 | if coinflip():
59 | def method3(self) -> None: ... # error: [override-of-final-method]
60 | def method4(self) -> None: ... # error: [override-of-final-method]
61 | # TODO: we should emit Liskov violations here too:
62 | if coinflip():
63 | method4 = 42 # error: [override-of-final-method]
| ^^^^^^^ Overrides a definition from superclass `A`
64 | else:
65 | method4 = 56
|
info: `A.method4` is decorated with `@final`, forbidding overrides
--> src/mdtest_snippet.py:29:9

View File

@ -0,0 +1,94 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: final.md - Tests for the `@typing(_extensions).final` decorator - Overloads in statically-known branches in stub files
mdtest path: crates/ty_python_semantic/resources/mdtest/final.md
---
# Python source files
## mdtest_snippet.pyi
```
1 | import sys
2 | from typing_extensions import overload, final
3 |
4 | class Foo:
5 | if sys.version_info >= (3, 10):
6 | @overload
7 | @final
8 | def method(self, x: int) -> int: ...
9 | else:
10 | @overload
11 | def method(self, x: int) -> int: ...
12 | @overload
13 | def method(self, x: str) -> str: ...
14 |
15 | if sys.version_info >= (3, 10):
16 | @overload
17 | def method2(self, x: int) -> int: ...
18 | else:
19 | @overload
20 | @final
21 | def method2(self, x: int) -> int: ...
22 | @overload
23 | def method2(self, x: str) -> str: ...
24 |
25 | class Bar(Foo):
26 | @overload
27 | def method(self, x: int) -> int: ...
28 | @overload
29 | def method(self, x: str) -> str: ... # error: [override-of-final-method]
30 |
31 | # This is fine: the only overload that is marked `@final`
32 | # is in a statically-unreachable branch
33 | @overload
34 | def method2(self, x: int) -> int: ...
35 | @overload
36 | def method2(self, x: str) -> str: ...
```
# Diagnostics
```
error[override-of-final-method]: Cannot override `Foo.method`
--> src/mdtest_snippet.pyi:29:9
|
27 | def method(self, x: int) -> int: ...
28 | @overload
29 | def method(self, x: str) -> str: ... # error: [override-of-final-method]
| ^^^^^^ Overrides a definition from superclass `Foo`
30 |
31 | # This is fine: the only overload that is marked `@final`
|
info: `Foo.method` is decorated with `@final`, forbidding overrides
--> src/mdtest_snippet.pyi:7:9
|
5 | if sys.version_info >= (3, 10):
6 | @overload
7 | @final
| ------
8 | def method(self, x: int) -> int: ...
| ------ `Foo.method` defined here
9 | else:
10 | @overload
|
help: Remove all overloads for `method`
info: rule `override-of-final-method` is enabled by default
23 | def method2(self, x: str) -> str: ...
24 |
25 | class Bar(Foo):
- @overload
- def method(self, x: int) -> int: ...
- @overload
- def method(self, x: str) -> str: ... # error: [override-of-final-method]
26 +
27 + # error: [override-of-final-method]
28 |
29 | # This is fine: the only overload that is marked `@final`
30 | # is in a statically-unreachable branch
note: This is an unsafe fix and may change runtime behavior
```

View File

@ -0,0 +1,74 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: named_tuple.md - `NamedTuple` - Edge case: multiple reachable definitions with distinct issues
mdtest path: crates/ty_python_semantic/resources/mdtest/named_tuple.md
---
# Python source files
## mdtest_snippet.py
```
1 | from typing import NamedTuple
2 |
3 | def coinflip() -> bool:
4 | return True
5 |
6 | class Foo(NamedTuple):
7 | if coinflip():
8 | _asdict: bool # error: [invalid-named-tuple] "NamedTuple field `_asdict` cannot start with an underscore"
9 | else:
10 | # TODO: there should only be one diagnostic here...
11 | #
12 | # error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
13 | # error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
14 | _asdict = True
```
# Diagnostics
```
error[invalid-named-tuple]: NamedTuple field name cannot start with an underscore
--> src/mdtest_snippet.py:8:9
|
6 | class Foo(NamedTuple):
7 | if coinflip():
8 | _asdict: bool # error: [invalid-named-tuple] "NamedTuple field `_asdict` cannot start with an underscore"
| ^^^^^^^^^^^^^ Class definition will raise `TypeError` at runtime due to this field
9 | else:
10 | # TODO: there should only be one diagnostic here...
|
info: rule `invalid-named-tuple` is enabled by default
```
```
error[invalid-named-tuple]: Cannot overwrite NamedTuple attribute `_asdict`
--> src/mdtest_snippet.py:14:9
|
12 | # error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
13 | # error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
14 | _asdict = True
| ^^^^^^^
|
info: This will cause the class creation to fail at runtime
info: rule `invalid-named-tuple` is enabled by default
```
```
error[invalid-named-tuple]: Cannot overwrite NamedTuple attribute `_asdict`
--> src/mdtest_snippet.py:14:9
|
12 | # error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
13 | # error: [invalid-named-tuple] "Cannot overwrite NamedTuple attribute `_asdict`"
14 | _asdict = True
| ^^^^^^^
|
info: This will cause the class creation to fail at runtime
info: rule `invalid-named-tuple` is enabled by default
```

View File

@ -459,7 +459,7 @@ fn core_module_scope(db: &dyn Db, core_module: KnownModule) -> Option<ScopeId<'_
pub(super) fn place_from_bindings<'db>(
db: &'db dyn Db,
bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>,
) -> Place<'db> {
) -> PlaceWithDefinition<'db> {
place_from_bindings_impl(db, bindings_with_constraints, RequiresExplicitReExport::No)
}
@ -487,20 +487,21 @@ type DeclaredTypeAndConflictingTypes<'db> = (
pub(crate) struct PlaceFromDeclarationsResult<'db> {
place_and_quals: PlaceAndQualifiers<'db>,
conflicting_types: Option<Box<indexmap::set::Slice<Type<'db>>>>,
/// Contains `Some(declaration)` if the declared type originates from exactly one declaration.
/// Contains the first reachable declaration for this place, if any.
/// This field is used for backreferences in diagnostics.
pub(crate) single_declaration: Option<Definition<'db>>,
pub(crate) first_declaration: Option<Definition<'db>>,
}
impl<'db> PlaceFromDeclarationsResult<'db> {
fn conflict(
place_and_quals: PlaceAndQualifiers<'db>,
conflicting_types: Box<indexmap::set::Slice<Type<'db>>>,
first_declaration: Option<Definition<'db>>,
) -> Self {
PlaceFromDeclarationsResult {
place_and_quals,
conflicting_types: Some(conflicting_types),
single_declaration: None,
first_declaration,
}
}
@ -798,6 +799,7 @@ pub(crate) fn place_by_id<'db>(
if let Some(qualifiers) = declared.is_bare_final() {
let bindings = all_considered_bindings();
return place_from_bindings_impl(db, bindings, requires_explicit_reexport)
.place
.with_qualifiers(qualifiers);
}
@ -809,7 +811,7 @@ pub(crate) fn place_by_id<'db>(
qualifiers,
} if qualifiers.contains(TypeQualifiers::CLASS_VAR) => {
let bindings = all_considered_bindings();
match place_from_bindings_impl(db, bindings, requires_explicit_reexport) {
match place_from_bindings_impl(db, bindings, requires_explicit_reexport).place {
Place::Defined(inferred, origin, boundness) => Place::Defined(
UnionType::from_elements(db, [Type::unknown(), inferred]),
origin,
@ -835,7 +837,7 @@ pub(crate) fn place_by_id<'db>(
let boundness_analysis = bindings.boundness_analysis;
let inferred = place_from_bindings_impl(db, bindings, requires_explicit_reexport);
let place = match inferred {
let place = match inferred.place {
// Place is possibly undeclared and definitely unbound
Place::Undefined => {
// TODO: We probably don't want to report `AlwaysDefined` here. This requires a bit of
@ -864,7 +866,8 @@ pub(crate) fn place_by_id<'db>(
} => {
let bindings = all_considered_bindings();
let boundness_analysis = bindings.boundness_analysis;
let mut inferred = place_from_bindings_impl(db, bindings, requires_explicit_reexport);
let mut inferred =
place_from_bindings_impl(db, bindings, requires_explicit_reexport).place;
if boundness_analysis == BoundnessAnalysis::AssumeBound {
if let Place::Defined(ty, origin, Definedness::PossiblyUndefined) = inferred {
@ -1010,7 +1013,7 @@ fn place_from_bindings_impl<'db>(
db: &'db dyn Db,
bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>,
requires_explicit_reexport: RequiresExplicitReExport,
) -> Place<'db> {
) -> PlaceWithDefinition<'db> {
let predicates = bindings_with_constraints.predicates;
let reachability_constraints = bindings_with_constraints.reachability_constraints;
let boundness_analysis = bindings_with_constraints.boundness_analysis;
@ -1039,6 +1042,8 @@ fn place_from_bindings_impl<'db>(
})
};
let mut first_definition = None;
let mut types = bindings_with_constraints.filter_map(
|BindingWithConstraints {
binding,
@ -1119,12 +1124,13 @@ fn place_from_bindings_impl<'db>(
return None;
}
first_definition.get_or_insert(binding);
let binding_ty = binding_type(db, binding);
Some(narrowing_constraint.narrow(db, binding_ty, binding.place(db)))
},
);
if let Some(first) = types.next() {
let place = if let Some(first) = types.next() {
let ty = if let Some(second) = types.next() {
let mut builder = PublicTypeBuilder::new(db);
builder.add(first);
@ -1161,9 +1167,19 @@ fn place_from_bindings_impl<'db>(
}
} else {
Place::Undefined
};
PlaceWithDefinition {
place,
first_definition,
}
}
pub(super) struct PlaceWithDefinition<'db> {
pub(super) place: Place<'db>,
pub(super) first_definition: Option<Definition<'db>>,
}
/// Accumulates types from multiple bindings or declarations, and eventually builds a
/// union type from them.
///
@ -1294,7 +1310,6 @@ fn place_from_declarations_impl<'db>(
let boundness_analysis = declarations.boundness_analysis;
let mut declarations = declarations.peekable();
let mut first_declaration = None;
let mut exactly_one_declaration = false;
let is_non_exported = |declaration: Definition<'db>| {
requires_explicit_reexport.is_yes() && !is_reexported(db, declaration)
@ -1325,12 +1340,7 @@ fn place_from_declarations_impl<'db>(
return None;
}
if first_declaration.is_none() {
first_declaration = Some(declaration);
exactly_one_declaration = true;
} else {
exactly_one_declaration = false;
}
first_declaration.get_or_insert(declaration);
let static_reachability =
reachability_constraints.evaluate(db, predicates, reachability_constraint);
@ -1387,19 +1397,19 @@ fn place_from_declarations_impl<'db>(
.with_qualifiers(declared.qualifiers());
if let Some(conflicting) = conflicting {
PlaceFromDeclarationsResult::conflict(place_and_quals, conflicting)
PlaceFromDeclarationsResult::conflict(place_and_quals, conflicting, first_declaration)
} else {
PlaceFromDeclarationsResult {
place_and_quals,
conflicting_types: None,
single_declaration: first_declaration.filter(|_| exactly_one_declaration),
first_declaration,
}
}
} else {
PlaceFromDeclarationsResult {
place_and_quals: Place::Undefined.into(),
conflicting_types: None,
single_declaration: None,
first_declaration: None,
}
}
}

View File

@ -82,7 +82,7 @@ impl<'db> SemanticModel<'db> {
memberdef.member.name,
MemberDefinition {
ty: memberdef.member.ty,
definition: memberdef.definition,
first_reachable_definition: memberdef.first_reachable_definition,
},
);
}
@ -328,11 +328,11 @@ impl<'db> SemanticModel<'db> {
}
}
/// The type and definition (if available) of a symbol.
/// The type and definition of a symbol.
#[derive(Clone, Debug)]
pub struct MemberDefinition<'db> {
pub ty: Type<'db>,
pub definition: Option<Definition<'db>>,
pub first_reachable_definition: Definition<'db>,
}
/// A classification of symbol names.

View File

@ -1375,9 +1375,9 @@ pub(crate) struct Field<'db> {
pub(crate) declared_ty: Type<'db>,
/// Kind-specific metadata for this field
pub(crate) kind: FieldKind<'db>,
/// The original declaration of this field, if there is exactly one.
/// The first declaration of this field.
/// This field is used for backreferences in diagnostics.
pub(crate) single_declaration: Option<Definition<'db>>,
pub(crate) first_declaration: Option<Definition<'db>>,
}
impl Field<'_> {
@ -3039,7 +3039,7 @@ impl<'db> ClassLiteral<'db> {
let symbol = table.symbol(symbol_id);
let result = place_from_declarations(db, declarations.clone());
let single_declaration = result.single_declaration;
let first_declaration = result.first_declaration;
let attr = result.ignore_conflicting_declarations();
if attr.is_class_var() {
continue;
@ -3047,7 +3047,9 @@ impl<'db> ClassLiteral<'db> {
if let Some(attr_ty) = attr.place.ignore_possibly_undefined() {
let bindings = use_def.end_of_scope_symbol_bindings(symbol_id);
let mut default_ty = place_from_bindings(db, bindings).ignore_possibly_undefined();
let mut default_ty = place_from_bindings(db, bindings)
.place
.ignore_possibly_undefined();
default_ty =
default_ty.map(|ty| ty.apply_optional_specialization(db, specialization));
@ -3105,7 +3107,7 @@ impl<'db> ClassLiteral<'db> {
let mut field = Field {
declared_ty: attr_ty.apply_optional_specialization(db, specialization),
kind,
single_declaration,
first_declaration,
};
// Check if this is a KW_ONLY sentinel and mark subsequent fields as keyword-only
@ -3588,7 +3590,7 @@ impl<'db> ClassLiteral<'db> {
// The attribute is declared in the class body.
let bindings = use_def.end_of_scope_symbol_bindings(symbol_id);
let inferred = place_from_bindings(db, bindings);
let inferred = place_from_bindings(db, bindings).place;
let has_binding = !inferred.is_undefined();
if has_binding {
@ -3831,7 +3833,9 @@ impl<'db> VarianceInferable<'db> for ClassLiteral<'db> {
(symbol_id, place_and_qual)
})
.chain(use_def_map.all_end_of_scope_symbol_bindings().map(
|(symbol_id, bindings)| (symbol_id, place_from_bindings(db, bindings).into()),
|(symbol_id, bindings)| {
(symbol_id, place_from_bindings(db, bindings).place.into())
},
))
.filter_map(|(symbol_id, place_and_qual)| {
if let Some(name) = table.place(symbol_id).as_symbol().map(Symbol::name) {

View File

@ -77,7 +77,7 @@ pub(crate) fn enum_metadata<'db>(
let ignored_names: Option<Vec<&str>> = if let Some(ignore) = table.symbol_id("_ignore_") {
let ignore_bindings = use_def_map.all_reachable_symbol_bindings(ignore);
let ignore_place = place_from_bindings(db, ignore_bindings);
let ignore_place = place_from_bindings(db, ignore_bindings).place;
match ignore_place {
Place::Defined(Type::StringLiteral(ignored_names), _, _) => {
@ -111,7 +111,7 @@ pub(crate) fn enum_metadata<'db>(
return None;
}
let inferred = place_from_bindings(db, bindings);
let inferred = place_from_bindings(db, bindings).place;
let value_ty = match inferred {
Place::Undefined => {

View File

@ -373,7 +373,7 @@ impl<'db> OverloadLiteral<'db> {
.scoped_use_id(db, scope);
let Place::Defined(Type::FunctionLiteral(previous_type), _, Definedness::AlwaysDefined) =
place_from_bindings(db, use_def.bindings_at_use(use_id))
place_from_bindings(db, use_def.bindings_at_use(use_id)).place
else {
return None;
};

View File

@ -634,7 +634,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&self.context,
class,
&field_name,
field.single_declaration,
field.first_declaration,
);
}
@ -645,13 +645,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
) {
field_with_default_encountered =
Some((field_name, field.single_declaration));
Some((field_name, field.first_declaration));
} else if let Some(field_with_default) = field_with_default_encountered.as_ref()
{
report_namedtuple_field_without_default_after_field_with_default(
&self.context,
class,
(&field_name, field.single_declaration),
(&field_name, field.first_declaration),
field_with_default,
);
}
@ -1034,6 +1034,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.db(),
use_def.end_of_scope_symbol_bindings(place.as_symbol().unwrap()),
)
.place
{
if function.file(self.db()) != self.file() {
// If the function is not in this file, we don't need to check it.
@ -1727,6 +1728,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let prior_bindings = use_def.bindings_at_definition(declaration);
// unbound_ty is Never because for this check we don't care about unbound
let inferred_ty = place_from_bindings(self.db(), prior_bindings)
.place
.with_qualifiers(TypeQualifiers::empty())
.or_fall_back_to(self.db(), || {
// Fallback to bindings declared on `types.ModuleType` if it's a global symbol
@ -8673,7 +8675,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// If we're inferring types of deferred expressions, look them up from end-of-scope.
if self.is_deferred() {
let place = if let Some(place_id) = place_table.place_id(expr) {
place_from_bindings(db, use_def.all_reachable_bindings(place_id))
place_from_bindings(db, use_def.all_reachable_bindings(place_id)).place
} else {
assert!(
self.deferred_state.in_string_annotation(),
@ -8691,7 +8693,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
let use_id = expr_ref.scoped_use_id(db, scope);
let place = place_from_bindings(db, use_def.bindings_at_use(use_id));
let place = place_from_bindings(db, use_def.bindings_at_use(use_id)).place;
(place, Some(use_id))
}
@ -8832,7 +8834,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
EnclosingSnapshotResult::FoundBindings(bindings) => {
let place = place_from_bindings(db, bindings).map_type(|ty| {
let place = place_from_bindings(db, bindings).place.map_type(|ty| {
self.narrow_place_with_applicable_constraints(
place_expr,
ty,
@ -8952,7 +8954,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return Place::Undefined.into();
}
EnclosingSnapshotResult::FoundBindings(bindings) => {
let place = place_from_bindings(db, bindings).map_type(|ty| {
let place =
place_from_bindings(db, bindings).place.map_type(|ty| {
self.narrow_place_with_applicable_constraints(
place_expr,
ty,

View File

@ -12,7 +12,9 @@ use rustc_hash::FxHashSet;
use crate::{
Db, NameKind,
place::{Place, imported_symbol, place_from_bindings, place_from_declarations},
place::{
Place, PlaceWithDefinition, imported_symbol, place_from_bindings, place_from_declarations,
},
semantic_index::{
attribute_scopes, definition::Definition, global_scope, place_table, scope::ScopeId,
semantic_index, use_def_map,
@ -35,47 +37,39 @@ pub(crate) fn all_members_of_scope<'db>(
.all_end_of_scope_symbol_declarations()
.filter_map(move |(symbol_id, declarations)| {
let place_result = place_from_declarations(db, declarations);
let definition = place_result.single_declaration;
place_result
let first_reachable_definition = place_result.first_declaration?;
let ty = place_result
.ignore_conflicting_declarations()
.place
.ignore_possibly_undefined()
.map(|ty| {
.ignore_possibly_undefined()?;
let symbol = table.symbol(symbol_id);
let member = Member {
name: symbol.name().clone(),
ty,
};
MemberWithDefinition { member, definition }
Some(MemberWithDefinition {
member,
first_reachable_definition,
})
})
.chain(use_def_map.all_end_of_scope_symbol_bindings().filter_map(
move |(symbol_id, bindings)| {
// It's not clear to AG how to using a bindings
// iterator here to get the correct definition for
// this binding. Below, we look through all bindings
// with a definition and only take one if there is
// exactly one. I don't think this can be wrong, but
// it's probably omitting definitions in some cases.
let mut definition = None;
for binding in bindings.clone() {
if let Some(def) = binding.binding.definition() {
if definition.is_some() {
definition = None;
break;
}
definition = Some(def);
}
}
place_from_bindings(db, bindings)
.ignore_possibly_undefined()
.map(|ty| {
let PlaceWithDefinition {
place,
first_definition,
} = place_from_bindings(db, bindings);
let first_reachable_definition = first_definition?;
let ty = place.ignore_possibly_undefined()?;
let symbol = table.symbol(symbol_id);
let member = Member {
name: symbol.name().clone(),
ty,
};
MemberWithDefinition { member, definition }
Some(MemberWithDefinition {
member,
first_reachable_definition,
})
},
))
@ -457,11 +451,11 @@ impl<'db> AllMembers<'db> {
}
}
/// A member of a type or scope, with an optional definition.
/// A member of a type or scope, with the first reachable definition of that member.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct MemberWithDefinition<'db> {
pub member: Member<'db>,
pub definition: Option<Definition<'db>>,
pub first_reachable_definition: Definition<'db>,
}
/// A member of a type or scope.

View File

@ -75,7 +75,7 @@ pub(super) fn class_member<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str
// Otherwise, we need to check if the symbol has bindings
let use_def = use_def_map(db, scope);
let bindings = use_def.end_of_scope_symbol_bindings(symbol_id);
let inferred = place_from_bindings(db, bindings);
let inferred = place_from_bindings(db, bindings).place;
// TODO: we should not need to calculate inferred type second time. This is a temporary
// solution until the notion of Boundness and Declaredness is split. See #16036, #16264

View File

@ -12,8 +12,8 @@ use crate::{
lint::LintId,
place::Place,
semantic_index::{
definition::DefinitionKind, place_table, scope::ScopeId, symbol::ScopedSymbolId,
use_def_map,
definition::DefinitionKind, place::ScopedPlaceId, place_table, scope::ScopeId,
symbol::ScopedSymbolId, use_def_map,
},
types::{
ClassBase, ClassLiteral, ClassType, KnownClass, Type,
@ -53,10 +53,11 @@ pub(super) fn check_class<'db>(context: &InferContext<'db, '_>, class: ClassLite
}
let class_specialized = class.identity_specialization(db);
let own_class_members: FxHashSet<_> = all_members_of_scope(db, class.body_scope(db)).collect();
let scope = class.body_scope(db);
let own_class_members: FxHashSet<_> = all_members_of_scope(db, scope).collect();
for member in own_class_members {
check_class_declaration(context, configuration, class_specialized, &member);
check_class_declaration(context, configuration, class_specialized, scope, &member);
}
}
@ -64,6 +65,7 @@ fn check_class_declaration<'db>(
context: &InferContext<'db, '_>,
configuration: OverrideRulesConfig,
class: ClassType<'db>,
class_scope: ScopeId<'db>,
member: &MemberWithDefinition<'db>,
) {
/// Salsa-tracked query to check whether any of the definitions of a symbol
@ -103,11 +105,10 @@ fn check_class_declaration<'db>(
let db = context.db();
let MemberWithDefinition { member, definition } = member;
let Some(definition) = definition else {
return;
};
let MemberWithDefinition {
member,
first_reachable_definition,
} = member;
let Place::Defined(type_on_subclass_instance, _, _) =
Type::instance(db, class).member(db, &member.name).place
@ -126,12 +127,14 @@ fn check_class_declaration<'db>(
if class_kind == Some(CodeGeneratorKind::NamedTuple)
&& configuration.check_prohibited_named_tuple_attrs()
&& PROHIBITED_NAMEDTUPLE_ATTRS.contains(&member.name.as_str())
// accessing `.kind()` here is fine as `definition`
// will always be a definition in the file currently being checked
&& !matches!(definition.kind(db), DefinitionKind::AnnotatedAssignment(_))
&& let Some(symbol_id) = place_table(db, class_scope).symbol_id(&member.name)
&& let Some(bad_definition) = use_def_map(db, class_scope)
.all_reachable_bindings(ScopedPlaceId::Symbol(symbol_id))
.filter_map(|binding| binding.binding.definition())
.find(|def| !matches!(def.kind(db), DefinitionKind::AnnotatedAssignment(_)))
&& let Some(builder) = context.report_lint(
&INVALID_NAMED_TUPLE,
definition.focus_range(db, context.module()),
bad_definition.focus_range(db, context.module()),
)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
@ -187,8 +190,6 @@ fn check_class_declaration<'db>(
.unwrap_or_default();
}
subclass_overrides_superclass_declaration = true;
let Place::Defined(superclass_type, _, _) = Type::instance(db, superclass)
.member(db, &member.name)
.place
@ -197,6 +198,8 @@ fn check_class_declaration<'db>(
break;
};
subclass_overrides_superclass_declaration = true;
if configuration.check_final_method_overridden() {
overridden_final_method = overridden_final_method.or_else(|| {
let superclass_symbol_id = superclass_symbol_id?;
@ -272,7 +275,7 @@ fn check_class_declaration<'db>(
context,
&member.name,
class,
*definition,
*first_reachable_definition,
subclass_function,
superclass,
superclass_type,
@ -308,7 +311,7 @@ fn check_class_declaration<'db>(
&& !has_dynamic_superclass
// accessing `.kind()` here is fine as `definition`
// will always be a definition in the file currently being checked
&& definition.kind(db).is_function_def()
&& first_reachable_definition.kind(db).is_function_def()
{
check_explicit_overrides(context, member, class);
}
@ -317,7 +320,7 @@ fn check_class_declaration<'db>(
report_overridden_final_method(
context,
&member.name,
*definition,
*first_reachable_definition,
member.ty,
superclass,
class,
@ -396,16 +399,16 @@ fn check_explicit_overrides<'db>(
let Some(functions) = underlying_functions else {
return;
};
if !functions
let Some(decorated_function) = functions
.iter()
.any(|function| function.has_known_decorator(db, FunctionDecorators::OVERRIDE))
{
.find(|function| function.has_known_decorator(db, FunctionDecorators::OVERRIDE))
else {
return;
}
};
let function_literal = if context.in_stub() {
functions[0].first_overload_or_implementation(db)
decorated_function.first_overload_or_implementation(db)
} else {
functions[0].literal(db).last_definition(db)
decorated_function.literal(db).last_definition(db)
};
let Some(builder) = context.report_lint(

View File

@ -896,7 +896,10 @@ fn cached_protocol_interface<'db>(
// type narrowing that uses `isinstance()` or `issubclass()` with
// runtime-checkable protocols.
for (symbol_id, bindings) in use_def_map.all_end_of_scope_symbol_bindings() {
let Some(ty) = place_from_bindings(db, bindings).ignore_possibly_undefined() else {
let Some(ty) = place_from_bindings(db, bindings)
.place
.ignore_possibly_undefined()
else {
continue;
};
direct_members.insert(

View File

@ -365,7 +365,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
};
let add_item_definition_subdiagnostic = |diagnostic: &mut Diagnostic, message| {
if let Some(declaration) = item.single_declaration {
if let Some(declaration) = item.first_declaration {
let file = declaration.file(db);
let module = parsed_module(db, file).load(db);