[ty] Narrow types after NoReturn calls in if branches

When a branch calls a NoReturn function, use the negation of the condition
to narrow types after the if statement. For example, after
`if val is None: sys.exit()`, `val` is now correctly narrowed to `int`
instead of `int | None`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Alex Gaynor 2025-12-05 20:24:10 -05:00
parent ef45c97dab
commit 6c73a5cff5
5 changed files with 270 additions and 6 deletions

View File

@ -49,3 +49,88 @@ def _(x: int | None):
reveal_type(x) # revealed: int
```
## Narrowing after a NoReturn call in one branch
When a branch calls a function that returns `NoReturn`/`Never`, we know that branch terminates and
doesn't contribute to the type after the if statement.
```py
import sys
def _(val: int | None):
if val is None:
sys.exit()
# After the if statement, val cannot be None because that case
# would have called sys.exit() and never reached here
reveal_type(val) # revealed: int
```
This also works when the NoReturn function is called in the else branch:
```py
import sys
def _(val: int | None):
if val is not None:
pass
else:
sys.exit()
reveal_type(val) # revealed: int
```
And for elif branches:
```py
import sys
def _(val: int | str | None):
if val is None:
sys.exit()
elif isinstance(val, int):
pass
else:
sys.exit()
# TODO: Should be `int`, but we don't yet fully support narrowing after NoReturn in elif chains
reveal_type(val) # revealed: int | str
```
## Narrowing from assert should not affect reassigned variables
When a variable is reassigned after an `assert`, the narrowing from the assert should not apply to
the new value. The assert condition was about the old value, not the new one.
```py
def foo(arg: int) -> int | None:
return None
def bar() -> None:
v = foo(1)
assert v is None
v = foo(2)
# v was reassigned, so the assert narrowing shouldn't apply
reveal_type(v) # revealed: int | None
```
## Narrowing from NoReturn should not affect reassigned variables
Similar to assert, when a variable is narrowed due to a NoReturn call in one branch and then
reassigned, the narrowing should only apply before the reassignment, not after.
```py
import sys
def foo() -> int | None:
return 3
def bar():
v = foo()
if v is None:
sys.exit()
reveal_type(v) # revealed: int
v = foo()
# v was reassigned, so the NoReturn narrowing shouldn't apply
reveal_type(v) # revealed: int | None
```

View File

@ -618,9 +618,7 @@ def g(x: int | None):
if x is None:
sys.exit(1)
# TODO: should be just `int`, not `int | None`
# See https://github.com/astral-sh/ty/issues/685
reveal_type(x) # revealed: int | None
reveal_type(x) # revealed: int
```
### Possibly unresolved diagnostics

View File

@ -1126,7 +1126,21 @@ fn place_from_bindings_impl<'db>(
first_definition.get_or_insert(binding);
let binding_ty = binding_type(db, binding);
Some(narrowing_constraint.narrow(db, binding_ty, binding.place(db)))
let place_id = binding.place(db);
// Apply explicit narrowing constraints first
let narrowed_ty = narrowing_constraint.narrow(db, binding_ty, place_id);
// Also apply narrowing from the reachability constraint.
// This handles cases like `if x is None: noreturn_func()` where the reachability
// constraint encodes that `x is not None` for the code to be reachable.
let narrowed_ty = reachability_constraints.narrow_by_reachability(
db,
predicates,
reachability_constraint,
place_id,
binding,
narrowed_ty,
);
Some(narrowed_ty)
},
);

View File

@ -2,7 +2,7 @@ use std::ops::Deref;
use ruff_db::files::{File, FileRange};
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast as ast;
use ruff_python_ast::{self as ast, NodeIndex};
use ruff_text_size::{Ranged, TextRange};
use crate::Db;
@ -747,6 +747,37 @@ impl DefinitionKind<'_> {
matches!(self, DefinitionKind::Function(_))
}
/// Returns the [`NodeIndex`] of the definition target.
///
/// This can be used to determine the relative ordering of definitions and expressions
/// in the AST without needing to access the parsed module.
pub(crate) fn target_node_index(&self) -> NodeIndex {
match self {
DefinitionKind::Import(import) => import.node.index(),
DefinitionKind::ImportFrom(import) => import.node.index(),
DefinitionKind::ImportFromSubmodule(import) => import.node.index(),
DefinitionKind::StarImport(import) => import.node.index(),
DefinitionKind::Function(function) => function.index(),
DefinitionKind::Class(class) => class.index(),
DefinitionKind::TypeAlias(type_alias) => type_alias.index(),
DefinitionKind::NamedExpression(named) => named.index(),
DefinitionKind::Assignment(assignment) => assignment.target.index(),
DefinitionKind::AnnotatedAssignment(assign) => assign.target.index(),
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.index(),
DefinitionKind::For(for_stmt) => for_stmt.target.index(),
DefinitionKind::Comprehension(comp) => comp.target.index(),
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.index(),
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.index(),
DefinitionKind::Parameter(parameter) => parameter.index(),
DefinitionKind::WithItem(with_item) => with_item.target.index(),
DefinitionKind::MatchPattern(match_pattern) => match_pattern.identifier.index(),
DefinitionKind::ExceptHandler(handler) => handler.handler.index(),
DefinitionKind::TypeVar(type_var) => type_var.index(),
DefinitionKind::ParamSpec(param_spec) => param_spec.index(),
DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.index(),
}
}
/// Returns the [`TextRange`] of the definition target.
///
/// A definition target would mainly be the node representing the place being defined i.e.,

View File

@ -209,7 +209,7 @@ use crate::semantic_index::predicate::{
};
use crate::types::{
CallableTypes, IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType,
infer_expression_type, static_expression_truthiness,
infer_expression_type, infer_narrowing_constraint, static_expression_truthiness,
};
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@ -953,4 +953,140 @@ impl ReachabilityConstraints {
}
}
}
/// Narrow a type based on the predicates in a reachability constraint.
///
/// When code is only reachable under certain conditions (encoded in the reachability constraint),
/// those conditions can be used to narrow types. For example, if code is only reachable when
/// `val is not None`, we can narrow `val` from `int | None` to `int`.
///
/// This function traverses the reachability constraint TDD and extracts predicates that can
/// narrow the given place, then applies them to the base type.
///
/// The `binding` parameter is used to filter out predicates that were created before the
/// binding. Such predicates are about an earlier binding of the same place and shouldn't
/// be used to narrow the current binding.
pub(crate) fn narrow_by_reachability<'db>(
&self,
db: &'db dyn Db,
predicates: &Predicates<'db>,
id: ScopedReachabilityConstraintId,
place: crate::semantic_index::place::ScopedPlaceId,
binding: crate::semantic_index::definition::Definition<'db>,
base_ty: Type<'db>,
) -> Type<'db> {
// Get the binding's NodeIndex. We'll use this to filter out predicates that were
// created before this binding (and thus are about an earlier binding of the same place).
let binding_index = binding.kind(db).target_node_index();
// Traverse the TDD and collect predicates that affect reachability.
// For each predicate, we check if one branch leads to ALWAYS_FALSE while
// the other leads to something reachable. If so, we know the predicate
// must have the value that leads to the reachable branch.
let mut result_ty = base_ty;
let mut current_id = id;
loop {
let node = match current_id {
ALWAYS_TRUE | AMBIGUOUS | ALWAYS_FALSE => break,
_ => {
let raw_index = current_id.as_u32() as usize;
if !self.used_indices.get_bit(raw_index).unwrap_or(false) {
break;
}
let index = self.used_indices.rank(raw_index) as usize;
self.used_interiors[index]
}
};
let predicate = &predicates[node.atom];
let truthiness = Self::analyze_single(db, predicate);
// Check if one branch leads to ALWAYS_FALSE (unreachable).
// If so, we can narrow based on the condition that makes the other branch reachable.
//
// For example, in: `if x is None: sys.exit()`
// The predicate `x is None` being true leads to the NoReturn call,
// which makes that branch unreachable. So the code after the if statement
// is only reachable when `x is None` is false, i.e., `x is not None`.
//
// We only apply this narrowing when:
// 1. The predicate is NOT a ReturnsNever predicate (those just mark reachability)
// 2. The predicate evaluates to Ambiguous (we don't know the value statically)
// 3. Exactly one branch leads to ALWAYS_FALSE
if truthiness == Truthiness::Ambiguous
&& !matches!(predicate.node, PredicateNode::ReturnsNever(_))
{
let apply_predicate =
if node.if_true == ALWAYS_FALSE && node.if_false != ALWAYS_FALSE {
// The true branch is unreachable (e.g., `if x is None: sys.exit()`).
// The predicate must be false for this code to be reachable.
Some(Predicate {
node: predicate.node,
is_positive: !predicate.is_positive,
})
} else if node.if_false == ALWAYS_FALSE && node.if_true != ALWAYS_FALSE {
// The false branch is unreachable. The predicate must be true.
// This handles: `if x is not None: pass else: sys.exit()` and
// `assert x is not None`.
Some(Predicate {
node: predicate.node,
is_positive: predicate.is_positive,
})
} else {
None
};
if let Some(pred) = apply_predicate {
// Check if the predicate was created before the binding.
// If so, skip narrowing because the predicate is about an earlier binding.
let predicate_index = match pred.node {
PredicateNode::Expression(expr) => Some(expr._node_ref(db).index()),
PredicateNode::Pattern(pattern) => {
Some(pattern.subject(db)._node_ref(db).index())
}
PredicateNode::ReturnsNever(_)
| PredicateNode::StarImportPlaceholder(_) => {
// These predicates don't narrow places
None
}
};
// If the predicate was created before the binding, it's about an earlier
// binding of the same place and shouldn't narrow the current binding.
if predicate_index.is_some_and(|idx| idx < binding_index) {
current_id = match truthiness {
Truthiness::AlwaysTrue => node.if_true,
Truthiness::Ambiguous => node.if_ambiguous,
Truthiness::AlwaysFalse => node.if_false,
};
continue;
}
if let Some(narrowed) = infer_narrowing_constraint(db, pred, place) {
let candidate = IntersectionBuilder::new(db)
.add_positive(result_ty)
.add_positive(narrowed)
.build();
// Only apply narrowing if the result is not Never.
// A Never result indicates that the narrowing predicate refers to a
// different value than the current binding (e.g., for assignments where
// the reachability constraint is about the old value, not the new one).
if !candidate.is_equivalent_to(db, Type::Never) {
result_ty = candidate;
}
}
}
}
// Follow the branch based on the evaluated truthiness
current_id = match truthiness {
Truthiness::AlwaysTrue => node.if_true,
Truthiness::Ambiguous => node.if_ambiguous,
Truthiness::AlwaysFalse => node.if_false,
};
}
result_ty
}
}