[ty] Eagerly evaluate reachability constraints using 'self' to Ambigous

This commit is contained in:
David Peter 2025-08-19 14:58:55 +02:00
parent 600245478c
commit 1ef3a982fb
4 changed files with 56 additions and 21 deletions

View File

@ -94,6 +94,7 @@ pub(super) struct SemanticIndexBuilder<'db, 'ast> {
has_future_annotations: bool, has_future_annotations: bool,
/// Whether we are currently visiting an `if TYPE_CHECKING` block. /// Whether we are currently visiting an `if TYPE_CHECKING` block.
in_type_checking_block: bool, in_type_checking_block: bool,
used_self_in_expr: bool,
// Used for checking semantic syntax errors // Used for checking semantic syntax errors
python_version: PythonVersion, python_version: PythonVersion,
@ -135,6 +136,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
has_future_annotations: false, has_future_annotations: false,
in_type_checking_block: false, in_type_checking_block: false,
used_self_in_expr: false,
scopes: IndexVec::new(), scopes: IndexVec::new(),
place_tables: IndexVec::new(), place_tables: IndexVec::new(),
@ -214,6 +216,12 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
} }
} }
fn visit_expr_and_detect_self_usage(&mut self, expr: &'ast ast::Expr) -> bool {
self.used_self_in_expr = false;
self.visit_expr(expr);
self.used_self_in_expr
}
/// Push a new loop, returning the outer loop, if any. /// Push a new loop, returning the outer loop, if any.
fn push_loop(&mut self) -> Option<Loop> { fn push_loop(&mut self) -> Option<Loop> {
self.current_scope_info_mut() self.current_scope_info_mut()
@ -668,14 +676,19 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
fn record_expression_narrowing_constraint( fn record_expression_narrowing_constraint(
&mut self, &mut self,
precide_node: &ast::Expr, predicate_node: &ast::Expr,
predicate_used_self: bool,
) -> PredicateOrLiteral<'db> { ) -> PredicateOrLiteral<'db> {
let predicate = self.build_predicate(precide_node); let predicate = self.build_predicate(predicate_node, predicate_used_self);
self.record_narrowing_constraint(predicate); self.record_narrowing_constraint(predicate);
predicate predicate
} }
fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> { fn build_predicate(
&mut self,
predicate_node: &ast::Expr,
predicate_used_self: bool,
) -> PredicateOrLiteral<'db> {
// Some commonly used test expressions are eagerly evaluated as `true` // Some commonly used test expressions are eagerly evaluated as `true`
// or `false` here for performance reasons. This list does not need to // or `false` here for performance reasons. This list does not need to
// be exhaustive. More complex expressions will still evaluate to the // be exhaustive. More complex expressions will still evaluate to the
@ -701,6 +714,10 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
let expression = self.add_standalone_expression(predicate_node); let expression = self.add_standalone_expression(predicate_node);
if predicate_used_self {
return PredicateOrLiteral::Ambiguous;
}
match resolve_to_literal(predicate_node) { match resolve_to_literal(predicate_node) {
Some(literal) => PredicateOrLiteral::Literal(literal), Some(literal) => PredicateOrLiteral::Literal(literal),
None => PredicateOrLiteral::Predicate(Predicate { None => PredicateOrLiteral::Predicate(Predicate {
@ -1042,8 +1059,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
); );
for if_expr in &generator.ifs { for if_expr in &generator.ifs {
self.visit_expr(if_expr); let self_usage = self.visit_expr_and_detect_self_usage(if_expr);
self.record_expression_narrowing_constraint(if_expr); self.record_expression_narrowing_constraint(if_expr, self_usage);
} }
for generator in generators_iter { for generator in generators_iter {
@ -1060,8 +1077,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
); );
for if_expr in &generator.ifs { for if_expr in &generator.ifs {
self.visit_expr(if_expr); let self_usage = self.visit_expr_and_detect_self_usage(if_expr);
self.record_expression_narrowing_constraint(if_expr); self.record_expression_narrowing_constraint(if_expr, self_usage);
} }
} }
@ -1557,8 +1574,8 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
// flow states and simplification of reachability constraints, since there is no way // flow states and simplification of reachability constraints, since there is no way
// of getting out of that `msg` branch. We simply restore to the post-test state. // of getting out of that `msg` branch. We simply restore to the post-test state.
self.visit_expr(test); let self_usage = self.visit_expr_and_detect_self_usage(test);
let predicate = self.build_predicate(test); let predicate = self.build_predicate(test, self_usage);
if let Some(msg) = msg { if let Some(msg) = msg {
let post_test = self.flow_snapshot(); let post_test = self.flow_snapshot();
@ -1671,9 +1688,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
} }
} }
ast::Stmt::If(node) => { ast::Stmt::If(node) => {
self.visit_expr(&node.test); let self_usage = self.visit_expr_and_detect_self_usage(&node.test);
let mut no_branch_taken = self.flow_snapshot(); let mut no_branch_taken = self.flow_snapshot();
let mut last_predicate = self.record_expression_narrowing_constraint(&node.test); let mut last_predicate =
self.record_expression_narrowing_constraint(&node.test, self_usage);
let mut last_reachability_constraint = let mut last_reachability_constraint =
self.record_reachability_constraint(last_predicate); self.record_reachability_constraint(last_predicate);
@ -1718,11 +1736,12 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
self.record_negated_reachability_constraint(last_reachability_constraint); self.record_negated_reachability_constraint(last_reachability_constraint);
if let Some(elif_test) = clause_test { if let Some(elif_test) = clause_test {
self.visit_expr(elif_test); let self_usage = self.visit_expr_and_detect_self_usage(elif_test);
// A test expression is evaluated whether the branch is taken or not // A test expression is evaluated whether the branch is taken or not
no_branch_taken = self.flow_snapshot(); no_branch_taken = self.flow_snapshot();
last_predicate = self.record_expression_narrowing_constraint(elif_test); last_predicate =
self.record_expression_narrowing_constraint(elif_test, self_usage);
last_reachability_constraint = last_reachability_constraint =
self.record_reachability_constraint(last_predicate); self.record_reachability_constraint(last_predicate);
@ -1764,10 +1783,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
range: _, range: _,
node_index: _, node_index: _,
}) => { }) => {
self.visit_expr(test); let self_usage = self.visit_expr_and_detect_self_usage(test);
let pre_loop = self.flow_snapshot(); let pre_loop = self.flow_snapshot();
let predicate = self.record_expression_narrowing_constraint(test); let predicate = self.record_expression_narrowing_constraint(test, self_usage);
self.record_reachability_constraint(predicate); self.record_reachability_constraint(predicate);
let outer_loop = self.push_loop(); let outer_loop = self.push_loop();
@ -2255,6 +2274,13 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
let node_key = NodeKey::from_node(expr); let node_key = NodeKey::from_node(expr);
if let ast::Expr::Name(ast::ExprName { id, .. }) = expr {
// TODO: proper detection of first parameter of methods
if id == "self" {
self.used_self_in_expr = true;
}
}
match expr { match expr {
ast::Expr::Name(ast::ExprName { ctx, .. }) ast::Expr::Name(ast::ExprName { ctx, .. })
| ast::Expr::Attribute(ast::ExprAttribute { ctx, .. }) | ast::Expr::Attribute(ast::ExprAttribute { ctx, .. })
@ -2431,9 +2457,9 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
ast::Expr::If(ast::ExprIf { ast::Expr::If(ast::ExprIf {
body, test, orelse, .. body, test, orelse, ..
}) => { }) => {
self.visit_expr(test); let self_usage = self.visit_expr_and_detect_self_usage(test);
let pre_if = self.flow_snapshot(); let pre_if = self.flow_snapshot();
let predicate = self.record_expression_narrowing_constraint(test); let predicate = self.record_expression_narrowing_constraint(test, self_usage);
let reachability_constraint = self.record_reachability_constraint(predicate); let reachability_constraint = self.record_reachability_constraint(predicate);
self.visit_expr(body); self.visit_expr(body);
let post_body = self.flow_snapshot(); let post_body = self.flow_snapshot();
@ -2509,12 +2535,12 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
.record_reachability_constraint(*id); // TODO: nicer API .record_reachability_constraint(*id); // TODO: nicer API
} }
self.visit_expr(value); let self_usage = self.visit_expr_and_detect_self_usage(value);
// For the last value, we don't need to model control flow. There is no short-circuiting // For the last value, we don't need to model control flow. There is no short-circuiting
// anymore. // anymore.
if index < values.len() - 1 { if index < values.len() - 1 {
let predicate = self.build_predicate(value); let predicate = self.build_predicate(value, self_usage);
let predicate_id = match op { let predicate_id = match op {
ast::BoolOp::And => self.add_predicate(predicate), ast::BoolOp::And => self.add_predicate(predicate),
ast::BoolOp::Or => self.add_negated_predicate(predicate), ast::BoolOp::Or => self.add_negated_predicate(predicate),

View File

@ -28,7 +28,10 @@ impl ScopedPredicateId {
/// A special ID that is used for an "always false" predicate. /// A special ID that is used for an "always false" predicate.
pub(crate) const ALWAYS_FALSE: ScopedPredicateId = ScopedPredicateId(0xffff_fffe); pub(crate) const ALWAYS_FALSE: ScopedPredicateId = ScopedPredicateId(0xffff_fffe);
const SMALLEST_TERMINAL: ScopedPredicateId = Self::ALWAYS_FALSE; /// A special ID that is used for an "ambiguous" predicate.
pub(crate) const AMBIGUOUS: ScopedPredicateId = ScopedPredicateId(0xffff_fffd);
const SMALLEST_TERMINAL: ScopedPredicateId = Self::AMBIGUOUS;
fn is_terminal(self) -> bool { fn is_terminal(self) -> bool {
self >= Self::SMALLEST_TERMINAL self >= Self::SMALLEST_TERMINAL
@ -86,6 +89,7 @@ pub(crate) struct Predicate<'db> {
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) enum PredicateOrLiteral<'db> { pub(crate) enum PredicateOrLiteral<'db> {
Literal(bool), Literal(bool),
Ambiguous,
Predicate(Predicate<'db>), Predicate(Predicate<'db>),
} }
@ -93,6 +97,7 @@ impl PredicateOrLiteral<'_> {
pub(crate) fn negated(self) -> Self { pub(crate) fn negated(self) -> Self {
match self { match self {
PredicateOrLiteral::Literal(value) => PredicateOrLiteral::Literal(!value), PredicateOrLiteral::Literal(value) => PredicateOrLiteral::Literal(!value),
PredicateOrLiteral::Ambiguous => PredicateOrLiteral::Ambiguous,
PredicateOrLiteral::Predicate(Predicate { node, is_positive }) => { PredicateOrLiteral::Predicate(Predicate { node, is_positive }) => {
PredicateOrLiteral::Predicate(Predicate { PredicateOrLiteral::Predicate(Predicate {
node, node,

View File

@ -477,6 +477,8 @@ impl ReachabilityConstraintsBuilder {
ScopedReachabilityConstraintId::ALWAYS_FALSE ScopedReachabilityConstraintId::ALWAYS_FALSE
} else if predicate == ScopedPredicateId::ALWAYS_TRUE { } else if predicate == ScopedPredicateId::ALWAYS_TRUE {
ScopedReachabilityConstraintId::ALWAYS_TRUE ScopedReachabilityConstraintId::ALWAYS_TRUE
} else if predicate == ScopedPredicateId::AMBIGUOUS {
ScopedReachabilityConstraintId::AMBIGUOUS
} else { } else {
self.add_interior(InteriorNode { self.add_interior(InteriorNode {
atom: predicate, atom: predicate,

View File

@ -981,14 +981,16 @@ impl<'db> UseDefMapBuilder<'db> {
PredicateOrLiteral::Predicate(predicate) => self.predicates.add_predicate(predicate), PredicateOrLiteral::Predicate(predicate) => self.predicates.add_predicate(predicate),
PredicateOrLiteral::Literal(true) => ScopedPredicateId::ALWAYS_TRUE, PredicateOrLiteral::Literal(true) => ScopedPredicateId::ALWAYS_TRUE,
PredicateOrLiteral::Literal(false) => ScopedPredicateId::ALWAYS_FALSE, PredicateOrLiteral::Literal(false) => ScopedPredicateId::ALWAYS_FALSE,
PredicateOrLiteral::Ambiguous => ScopedPredicateId::AMBIGUOUS,
} }
} }
pub(super) fn record_narrowing_constraint(&mut self, predicate: ScopedPredicateId) { pub(super) fn record_narrowing_constraint(&mut self, predicate: ScopedPredicateId) {
if predicate == ScopedPredicateId::ALWAYS_TRUE if predicate == ScopedPredicateId::ALWAYS_TRUE
|| predicate == ScopedPredicateId::ALWAYS_FALSE || predicate == ScopedPredicateId::ALWAYS_FALSE
|| predicate == ScopedPredicateId::AMBIGUOUS
{ {
// No need to record a narrowing constraint for `True` or `False`. // No need to record a narrowing constraint
return; return;
} }