[ty] Allow if type(x) is Y narrowing for types other than class-literal types (#22729)

## Summary

Fixes https://github.com/astral-sh/ty/issues/2565.

This PR adds support for `if type(x) is Y` narrowing where `Y` is a
subclass-of type, type-alias type, or typevar type.

## Test Plan

mdtests
This commit is contained in:
Alex Waygood
2026-01-20 08:23:00 +00:00
committed by GitHub
parent 0cbe2af18b
commit af886b06c9
2 changed files with 135 additions and 6 deletions

View File

@@ -318,3 +318,106 @@ def _(x: object):
if (type(y := x)) is bool:
reveal_type(y) # revealed: bool
```
## Narrowing where the right-hand side is not a class literal
```toml
[environment]
python-version = "3.12"
```
```py
from typing import final
class Foo: ...
def f(x: Foo, y: type[int]):
if type(x) is y:
reveal_type(x) # revealed: Foo & int
else:
reveal_type(x) # revealed: Foo
if type(x) is not y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo & int
@final
class Bar: ...
def g(x: object, y: type[Bar]):
if type(x) is y:
reveal_type(x) # revealed: Bar
else:
# `Bar` is `@final`, so we can do `else`-branch narrowing here
reveal_type(x) # revealed: ~Bar
if type(x) is not y:
reveal_type(x) # revealed: ~Bar
else:
reveal_type(x) # revealed: Bar
def j[T: int](x: Foo, y: type[T]):
if type(x) is y:
reveal_type(x) # revealed: Foo & int
else:
reveal_type(x) # revealed: Foo
if type(x) is not y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo & int
def k[T: type[int]](x: Foo, y: T):
if type(x) is y:
reveal_type(x) # revealed: Foo & int
else:
reveal_type(x) # revealed: Foo
if type(x) is not y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo & int
type IntClassAlias = type[int]
def strange(x: Foo, y: IntClassAlias):
if type(x) is y:
reveal_type(x) # revealed: Foo & int
else:
reveal_type(x) # revealed: Foo
if type(x) is not y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo & int
class Spam[T]: ...
def h(x: Foo, y: type[Spam[int]]):
# no narrowing can occur, because `Spam[int]` is a generic class,
# and `if type(x) is Y` is not a valid operation if `Y` could be
# a generic alias.
if type(x) is y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo
if type(x) is not y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo
def i[T](x: Foo, y: type[Spam[T]]):
# same here: no narrowing can occur
if type(x) is y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo
if type(x) is not y:
reveal_type(x) # revealed: Foo
else:
reveal_type(x) # revealed: Foo
```

View File

@@ -996,6 +996,32 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
)
}
/// Attempt to find an underlying class literal for purposes of `if type(x) is Y` narrowing.
///
/// We deliberately return `None` for generic-alias types, since narrowing based
/// on `if type(x) is Y[int]` isn't valid (this expression will never return `true`
/// at runtime). Similarly, we return `None` for `type[Y[int]]`, type variables
/// bound to `type[Y[int]]`, and type aliases where the underlying value is a
/// generic class.
fn find_underlying_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<ClassLiteral<'db>> {
match ty {
Type::ClassLiteral(class) => Some(class),
Type::SubclassOf(subclass_of) => {
match subclass_of.subclass_of().with_transposed_type_var(db) {
SubclassOfInner::Class(ClassType::NonGeneric(class)) => Some(class),
SubclassOfInner::Class(ClassType::Generic(_))
| SubclassOfInner::Dynamic(_) => None,
SubclassOfInner::TypeVar(tvar) => {
find_underlying_class(db, tvar.typevar(db).upper_bound(db)?)
}
}
}
Type::TypeVar(tvar) => find_underlying_class(db, tvar.typevar(db).upper_bound(db)?),
Type::TypeAlias(alias) => find_underlying_class(db, alias.value_type(db)),
_ => None,
}
}
let ast::ExprCompare {
range: _,
node_index: _,
@@ -1196,8 +1222,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
// - `if type(x) is not Y`
// - `if Y is type(x)`
// - `if Y is not type(x)`
if let (ast::Expr::Call(call), _, _, Type::ClassLiteral(class))
| (_, Type::ClassLiteral(class), ast::Expr::Call(call), _) =
if let (ast::Expr::Call(call), _, _, other) | (_, other, ast::Expr::Call(call), _) =
(left, lhs_ty, right, rhs_ty)
{
let ast::ExprCall {
@@ -1226,17 +1251,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&& keywords.is_empty()
&& let [single_argument] = &**args
&& let Some(target) = PlaceExpr::try_from_expr(single_argument)
// `else`-branch narrowing for `if type(x) is Y` can only be done
// if `Y` is a final class
&& (is_positive || class.is_final(self.db))
&& let Type::ClassLiteral(called_class) = inference.expression_type(func)
&& called_class.is_known(self.db, KnownClass::Type)
&& let Some(other_class) = find_underlying_class(self.db, other)
// `else`-branch narrowing for `if type(x) is Y` can only be done
// if `Y` is a final class
&& (is_positive || other_class.is_final(self.db))
{
let place = self.expect_place(&target);
constraints.insert(
place,
NarrowingConstraint::intersection(
Type::instance(self.db, class.top_materialization(self.db))
Type::instance(self.db, other_class.top_materialization(self.db))
.negate_if(self.db, !is_positive),
),
);