diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md index fbbe5794d8..1651a0a31a 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md @@ -49,3 +49,88 @@ def _(x: int | None): reveal_type(x) # revealed: int ``` + +## Narrowing after a NoReturn call in one branch + +When a branch calls a function that returns `NoReturn`/`Never`, we know that branch terminates and +doesn't contribute to the type after the if statement. + +```py +import sys + +def _(val: int | None): + if val is None: + sys.exit() + # After the if statement, val cannot be None because that case + # would have called sys.exit() and never reached here + reveal_type(val) # revealed: int +``` + +This also works when the NoReturn function is called in the else branch: + +```py +import sys + +def _(val: int | None): + if val is not None: + pass + else: + sys.exit() + reveal_type(val) # revealed: int +``` + +And for elif branches: + +```py +import sys + +def _(val: int | str | None): + if val is None: + sys.exit() + elif isinstance(val, int): + pass + else: + sys.exit() + # TODO: Should be `int`, but we don't yet fully support narrowing after NoReturn in elif chains + reveal_type(val) # revealed: int | str +``` + +## Narrowing from assert should not affect reassigned variables + +When a variable is reassigned after an `assert`, the narrowing from the assert should not apply to +the new value. The assert condition was about the old value, not the new one. 
+ +```py +def foo(arg: int) -> int | None: + return None + +def bar() -> None: + v = foo(1) + assert v is None + + v = foo(2) + # v was reassigned, so the assert narrowing shouldn't apply + reveal_type(v) # revealed: int | None +``` + +## Narrowing from NoReturn should not affect reassigned variables + +Similar to assert, when a variable is narrowed due to a NoReturn call in one branch and then +reassigned, the narrowing should only apply before the reassignment, not after. + +```py +import sys + +def foo() -> int | None: + return 3 + +def bar(): + v = foo() + if v is None: + sys.exit() + reveal_type(v) # revealed: int + + v = foo() + # v was reassigned, so the NoReturn narrowing shouldn't apply + reveal_type(v) # revealed: int | None +``` diff --git a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md index 25d458ae67..05bf71894e 100644 --- a/crates/ty_python_semantic/resources/mdtest/terminal_statements.md +++ b/crates/ty_python_semantic/resources/mdtest/terminal_statements.md @@ -618,9 +618,7 @@ def g(x: int | None): if x is None: sys.exit(1) - # TODO: should be just `int`, not `int | None` - # See https://github.com/astral-sh/ty/issues/685 - reveal_type(x) # revealed: int | None + reveal_type(x) # revealed: int ``` ### Possibly unresolved diagnostics diff --git a/crates/ty_python_semantic/src/place.rs b/crates/ty_python_semantic/src/place.rs index 21dfb955a9..5d042f4c8e 100644 --- a/crates/ty_python_semantic/src/place.rs +++ b/crates/ty_python_semantic/src/place.rs @@ -1127,7 +1127,21 @@ fn place_from_bindings_impl<'db>( first_definition.get_or_insert(binding); let binding_ty = binding_type(db, binding); - Some(narrowing_constraint.narrow(db, binding_ty, binding.place(db))) + let place_id = binding.place(db); + // Apply explicit narrowing constraints first + let narrowed_ty = narrowing_constraint.narrow(db, binding_ty, place_id); + // Also apply narrowing from the reachability 
constraint. + // This handles cases like `if x is None: noreturn_func()` where the reachability + // constraint encodes that `x is not None` for the code to be reachable. + let narrowed_ty = reachability_constraints.narrow_by_reachability( + db, + predicates, + reachability_constraint, + place_id, + binding, + narrowed_ty, + ); + Some(narrowed_ty) }, ); diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index 2659e75493..ab3c3cd4d7 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use ruff_db::files::{File, FileRange}; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; -use ruff_python_ast as ast; +use ruff_python_ast::{self as ast, NodeIndex}; use ruff_text_size::{Ranged, TextRange}; use crate::Db; @@ -747,6 +747,37 @@ impl DefinitionKind<'_> { matches!(self, DefinitionKind::Function(_)) } + /// Returns the [`NodeIndex`] of the definition target. + /// + /// This can be used to determine the relative ordering of definitions and expressions + /// in the AST without needing to access the parsed module. 
+    pub(crate) fn target_node_index(&self) -> NodeIndex {
+        match self {
+            // Import-style definitions: the index is read through the `node` field.
+            Self::Import(def) => def.node.index(),
+            Self::ImportFrom(def) => def.node.index(),
+            Self::ImportFromSubmodule(def) => def.node.index(),
+            Self::StarImport(def) => def.node.index(),
+
+            // Definitions that carry a `target` field: the index is the target's.
+            Self::Assignment(def) => def.target.index(),
+            Self::AnnotatedAssignment(def) => def.target.index(),
+            Self::For(def) => def.target.index(),
+            Self::Comprehension(def) => def.target.index(),
+            Self::WithItem(def) => def.target.index(),
+
+            // Definitions that reach the index through some other named field.
+            Self::MatchPattern(def) => def.identifier.index(),
+            Self::ExceptHandler(def) => def.handler.index(),
+
+            // Everything else answers `index()` directly.
+            Self::Function(def) => def.index(),
+            Self::Class(def) => def.index(),
+            Self::TypeAlias(def) => def.index(),
+            Self::NamedExpression(def) => def.index(),
+            Self::AugmentedAssignment(def) => def.index(),
+            Self::VariadicPositionalParameter(def) => def.index(),
+            Self::VariadicKeywordParameter(def) => def.index(),
+            Self::Parameter(def) => def.index(),
+            Self::TypeVar(def) => def.index(),
+            Self::ParamSpec(def) => def.index(),
+            Self::TypeVarTuple(def) => def.index(),
+        }
+    }
+
     /// Returns the [`TextRange`] of the definition target.
/// /// A definition target would mainly be the node representing the place being defined i.e., diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 9da8bbe87a..febc645b70 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -209,7 +209,7 @@ use crate::semantic_index::predicate::{ }; use crate::types::{ CallableTypes, IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType, - infer_expression_type, static_expression_truthiness, + infer_expression_type, infer_narrowing_constraint, static_expression_truthiness, }; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula @@ -953,4 +953,140 @@ impl ReachabilityConstraints { } } } + + /// Narrow a type based on the predicates in a reachability constraint. + /// + /// When code is only reachable under certain conditions (encoded in the reachability constraint), + /// those conditions can be used to narrow types. For example, if code is only reachable when + /// `val is not None`, we can narrow `val` from `int | None` to `int`. + /// + /// This function traverses the reachability constraint TDD and extracts predicates that can + /// narrow the given place, then applies them to the base type. + /// + /// The `binding` parameter is used to filter out predicates that were created before the + /// binding. Such predicates are about an earlier binding of the same place and shouldn't + /// be used to narrow the current binding. 
+    pub(crate) fn narrow_by_reachability<'db>(
+        &self,
+        db: &'db dyn Db,
+        predicates: &Predicates<'db>,
+        id: ScopedReachabilityConstraintId,
+        place: crate::semantic_index::place::ScopedPlaceId,
+        binding: crate::semantic_index::definition::Definition<'db>,
+        base_ty: Type<'db>,
+    ) -> Type<'db> {
+        // AST position of the binding being narrowed. A predicate whose expression sits before
+        // this position is treated as describing an earlier binding of the same place.
+        let binding_index = binding.kind(db).target_node_index();
+
+        let mut narrowed_ty = base_ty;
+        let mut cursor = id;
+
+        // Walk the TDD one interior node at a time, taking the edge selected by each
+        // predicate's evaluated truthiness, until a terminal is reached or the node is not
+        // marked in `used_indices`.
+        loop {
+            if matches!(cursor, ALWAYS_TRUE | AMBIGUOUS | ALWAYS_FALSE) {
+                break;
+            }
+            let slot = cursor.as_u32() as usize;
+            if !self.used_indices.get_bit(slot).unwrap_or(false) {
+                break;
+            }
+            let interior = self.used_interiors[self.used_indices.rank(slot) as usize];
+
+            let predicate = &predicates[interior.atom];
+            let truthiness = Self::analyze_single(db, predicate);
+
+            // A predicate contributes narrowing only when all of the following hold:
+            //   1. its truthiness evaluates to `Ambiguous` (not statically decided),
+            //   2. it is not a `ReturnsNever` predicate, and
+            //   3. exactly one of its true/false edges is `ALWAYS_FALSE`.
+            let is_returns_never = matches!(predicate.node, PredicateNode::ReturnsNever(_));
+            if truthiness == Truthiness::Ambiguous && !is_returns_never {
+                let true_edge_dead = interior.if_true == ALWAYS_FALSE;
+                let false_edge_dead = interior.if_false == ALWAYS_FALSE;
+
+                // Polarity the predicate must have for the current point to be reachable.
+                let required_polarity = match (true_edge_dead, false_edge_dead) {
+                    // Only the true edge is unreachable (e.g. `if x is None: sys.exit()`),
+                    // so the predicate has to be false here: flip its polarity.
+                    (true, false) => Some(!predicate.is_positive),
+                    // Only the false edge is unreachable (e.g. `if x is not None: pass` /
+                    // `else: sys.exit()`, or `assert x is not None`): keep the polarity.
+                    (false, true) => Some(predicate.is_positive),
+                    (true, true) | (false, false) => None,
+                };
+
+                if let Some(is_positive) = required_polarity {
+                    let implied = Predicate {
+                        node: predicate.node,
+                        is_positive,
+                    };
+
+                    // AST position of the predicate's expression (or pattern subject).
+                    // Predicate kinds that do not narrow places have no position.
+                    let predicate_index = match implied.node {
+                        PredicateNode::Expression(expr) => Some(expr._node_ref(db).index()),
+                        PredicateNode::Pattern(pattern) => {
+                            Some(pattern.subject(db)._node_ref(db).index())
+                        }
+                        PredicateNode::ReturnsNever(_)
+                        | PredicateNode::StarImportPlaceholder(_) => None,
+                    };
+                    let predates_binding =
+                        predicate_index.is_some_and(|idx| idx < binding_index);
+
+                    // A predicate positioned before the binding is about an earlier binding
+                    // of the place, so it is not asked for a narrowing constraint at all.
+                    let constraint = if predates_binding {
+                        None
+                    } else {
+                        infer_narrowing_constraint(db, implied, place)
+                    };
+
+                    if let Some(constraint) = constraint {
+                        let candidate = IntersectionBuilder::new(db)
+                            .add_positive(narrowed_ty)
+                            .add_positive(constraint)
+                            .build();
+                        // An intersection equivalent to `Never` is discarded: it signals
+                        // that the predicate refers to a different value than the current
+                        // binding (e.g. the old value of a reassigned variable).
+                        if !candidate.is_equivalent_to(db, Type::Never) {
+                            narrowed_ty = candidate;
+                        }
+                    }
+                }
+            }
+
+            // Single place where the walk advances, keyed on the evaluated truthiness.
+            cursor = match truthiness {
+                Truthiness::AlwaysTrue => interior.if_true,
+                Truthiness::Ambiguous => interior.if_ambiguous,
+                Truthiness::AlwaysFalse => interior.if_false,
+            };
+        }
+
+        narrowed_ty
+    }
 }