[ty] fix comparisons and arithmetic with NewTypes of float (#22105)

Fixes https://github.com/astral-sh/ty/issues/2077.
This commit is contained in:
Jack O'Connor
2026-01-06 09:32:22 -08:00
committed by GitHub
parent 01de8bef3e
commit ab1ac254d9
3 changed files with 254 additions and 19 deletions

View File

@@ -168,6 +168,56 @@ on top of that:
Foo = NewType("Foo", 42)
```
## `NewType`s in arithmetic and comparison expressions might or might not act as their base
These expressions are valid because `Foo` acts as its base type, `int`:
```py
from typing import NewType
Foo = NewType("Foo", int)
reveal_type(Foo(42) + 1) # revealed: int
reveal_type(1 + Foo(42)) # revealed: int
reveal_type(Foo(42) + Foo(42)) # revealed: int
reveal_type(Foo(42) == 42) # revealed: bool
reveal_type(42 == Foo(42)) # revealed: bool
reveal_type(Foo(42) == Foo(42)) # revealed: bool
```
However, we can't always substitute `int` for `Foo` to evaluate expressions like these. In the
following cases, only `Foo` itself is valid:
```py
class Bar:
def __add__(self, other: Foo) -> Foo:
return other
def __radd__(self, other: Foo) -> Foo:
return other
def __lt__(self, other: Foo) -> bool:
return True
def __gt__(self, other: Foo) -> bool:
return True
def __contains__(self, key: Foo) -> bool:
return True
reveal_type(Foo(42) + Bar()) # revealed: Foo
reveal_type(Bar() + Foo(42)) # revealed: Foo
reveal_type(Foo(42) < Bar()) # revealed: bool
reveal_type(Bar() < Foo(42)) # revealed: bool
reveal_type(Foo(42) in Bar()) # revealed: bool
42 + Bar() # error: [unsupported-operator]
Bar() + 42 # error: [unsupported-operator]
42 < Bar() # error: [unsupported-operator]
Bar() < 42 # error: [unsupported-operator]
42 in Bar() # error: [unsupported-operator]
```
## `float` and `complex` special cases
`float` and `complex` are subject to a special case in the typing spec, which we currently interpret
@@ -178,6 +228,7 @@ and we accept the unions they expand into.
```py
from typing import NewType
from ty_extensions import static_assert, is_assignable_to
Foo = NewType("Foo", float)
Foo(3.14)
@@ -186,6 +237,15 @@ Foo("hello") # error: [invalid-argument-type] "Argument is incorrect: Expected
reveal_type(Foo(3.14).__class__) # revealed: type[int] | type[float]
reveal_type(Foo(42).__class__) # revealed: type[int] | type[float]
static_assert(is_assignable_to(Foo, float))
static_assert(is_assignable_to(Foo, int | float))
static_assert(is_assignable_to(Foo, int | float | None))
# The assignments above require treating `Foo` as its underlying union type. Each of its members is
# assignable to the union on the right, so `Foo` is assignable to the union, even though `Foo` as a
# whole isn't assignable to any one member. However, as in the previous section, we need to be sure
# that this treatment doesn't break cases like the assignment below, where `Foo` *is* assignable to
# the union on the right, even though its members *aren't*.
static_assert(is_assignable_to(Foo, Foo | None))
Bar = NewType("Bar", complex)
Bar(1 + 2j)
@@ -196,6 +256,11 @@ Bar("goodbye") # error: [invalid-argument-type]
reveal_type(Bar(1 + 2j).__class__) # revealed: type[int] | type[float] | type[complex]
reveal_type(Bar(3.14).__class__) # revealed: type[int] | type[float] | type[complex]
reveal_type(Bar(42).__class__) # revealed: type[int] | type[float] | type[complex]
static_assert(is_assignable_to(Bar, complex))
static_assert(is_assignable_to(Bar, int | float | complex))
static_assert(is_assignable_to(Bar, int | float | complex | None))
# See the `Foo | None` case above.
static_assert(is_assignable_to(Bar, Bar | None))
```
We don't currently try to distinguish between an implicit union (e.g. `float`) and the equivalent
@@ -223,6 +288,52 @@ def g(_: Callable[[int | float | complex], Bar]): ...
g(Bar)
```
The arithmetic and comparison test cases in the previous section used a `NewType` of `int`, but
`NewType`s of `float` and `complex` are more complicated, because their base type is a union, and
that union needs special handling in binary expressions. In these examples, we we need to lower
`Foo` to `int | float` and then check each member of that union _individually_, as we would with an
explicit `Union` on the left side:
```py
reveal_type(Foo(3.14) < Foo(42)) # revealed: bool
reveal_type(Foo(3.14) == Foo(42)) # revealed: bool
reveal_type(Foo(3.14) + Foo(42)) # revealed: int | float
reveal_type(Foo(3.14) / Foo(42)) # revealed: int | float
```
But again as above, we can't _always_ lower `Foo` to `int | float`, because there are also binary
expressions where only `Foo` itself is valid:
```py
class Bing:
def __add__(self, other: Foo) -> Foo:
return other
def __radd__(self, other: Foo) -> Foo:
return other
def __lt__(self, other: Foo) -> bool:
return True
def __gt__(self, other: Foo) -> bool:
return True
def __contains__(self, key: Foo) -> bool:
return True
reveal_type(Foo(3.14) + Bing()) # revealed: Foo
reveal_type(Bing() + Foo(42)) # revealed: Foo
reveal_type(Foo(3.14) < Bing()) # revealed: bool
reveal_type(Bing() < Foo(42)) # revealed: bool
reveal_type(Foo(3.14) in Bing()) # revealed: bool
3.14 + Bing() # error: [unsupported-operator]
Bing() + 3.14 # error: [unsupported-operator]
3.14 < Bing() # error: [unsupported-operator]
Bing() < 3.14 # error: [unsupported-operator]
3.14 in Bing() # error: [unsupported-operator]
```
## A `NewType` definition must be a simple variable assignment
```py

View File

@@ -10426,6 +10426,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op,
),
// `try_call_bin_op` works for almost all `NewType`s, but not for `NewType`s of `float`
// and `complex`, where the concrete base type is a union. In that case it turns out
// the `self` types of the dunder methods in typeshed don't match, because they don't
// get the same `int | float` and `int | float | complex` special treatment that the
// positional arguments get. In those cases we need to explicitly delegate to the base
// type, so that it hits the `Type::Union` branches above.
(Type::NewTypeInstance(newtype), rhs, _) => {
Type::try_call_bin_op(self.db(), left_ty, op, right_ty)
.map(|outcome| outcome.return_type(self.db()))
.ok()
.or_else(|| {
self.infer_binary_expression_type(
node,
emitted_division_by_zero_diagnostic,
newtype.concrete_base_type(self.db()),
rhs,
op,
)
})
}
(lhs, Type::NewTypeInstance(newtype), _) => {
Type::try_call_bin_op(self.db(), left_ty, op, right_ty)
.map(|outcome| outcome.return_type(self.db()))
.ok()
.or_else(|| {
self.infer_binary_expression_type(
node,
emitted_division_by_zero_diagnostic,
lhs,
newtype.concrete_base_type(self.db()),
op,
)
})
}
// Non-todo Anys take precedence over Todos (as if we fix this `Todo` in the future,
// the result would then become Any or Unknown, respectively).
(div @ Type::Dynamic(DynamicType::Divergent(_)), _, _)
@@ -10762,8 +10797,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
| Type::TypeVar(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_),
| Type::TypedDict(_),
Type::FunctionLiteral(_)
| Type::BooleanLiteral(_)
| Type::Callable(..)
@@ -10793,8 +10827,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
| Type::TypeVar(_)
| Type::TypeIs(_)
| Type::TypeGuard(_)
| Type::TypedDict(_)
| Type::NewTypeInstance(_),
| Type::TypedDict(_),
op,
) => Type::try_call_bin_op(self.db(), left_ty, op, right_ty)
.map(|outcome| outcome.return_type(self.db()))
@@ -11228,6 +11261,39 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
)
})),
// `try_dunder` works for almost all `NewType`s, but not for `NewType`s of `float` and
// `complex`, where the concrete base type is a union. In that case it turns out the
// `self` types of the dunder methods in typeshed don't match, because they don't get
// the same `int | float` and `int | float | complex` special treatment that the
// positional arguments get. In those cases we need to explicitly delegate to the base
// type, so that it hits the `Type::Union` branches above.
(Type::NewTypeInstance(newtype), right) => Some(
try_dunder(self, MemberLookupPolicy::default()).or_else(|_| {
visitor.visit((left, op, right), || {
self.infer_binary_type_comparison(
newtype.concrete_base_type(self.db()),
op,
right,
range,
visitor,
)
})
}),
),
(left, Type::NewTypeInstance(newtype)) => Some(
try_dunder(self, MemberLookupPolicy::default()).or_else(|_| {
visitor.visit((left, op, right), || {
self.infer_binary_type_comparison(
left,
op,
newtype.concrete_base_type(self.db()),
range,
visitor,
)
})
}),
),
(Type::IntLiteral(n), Type::IntLiteral(m)) => Some(match op {
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),

View File

@@ -654,6 +654,79 @@ impl<'db> Type<'db> {
// `Never` is the bottom type, the empty set.
(_, Type::Never) => ConstraintSet::from(false),
(Type::NewTypeInstance(self_newtype), Type::NewTypeInstance(target_newtype)) => {
self_newtype.has_relation_to_impl(db, target_newtype)
}
// In the special cases of `NewType`s of `float` or `complex`, the concrete base type
// can be a union (`int | float` or `int | float | complex`). For that reason,
// `NewType` assignability to a union needs to consider two different cases. It could
// be that we need to treat the `NewType` as the underlying union it's assignable to,
// for example:
//
// ```py
// Foo = NewType("Foo", float)
// static_assert(is_assignable_to(Foo, float | None))
// ```
//
// The right side there is equivalent to `int | float | None`, but `Foo` as a whole
// isn't assignable to any of those three types. However, `Foo`s concrete base type is
// `int | float`, which is assignable, because union members on the left side get
// checked individually. On the other hand, we need to be careful not to break the
// following case, where `int | float` is *not* assignable to the right side:
//
// ```py
// static_assert(is_assignable_to(Foo, Foo | None))
// ```
//
// To handle both cases, we have to check that *either* `Foo` as a whole is assignable
// (or subtypeable etc.) *or* that its concrete base type is. Note that this match arm
// needs to take precedence over the `Type::Union` arms immediately below.
(Type::NewTypeInstance(self_newtype), Type::Union(union)) => {
// First the normal "assign to union" case, unfortunately duplicated from below.
union
.elements(db)
.iter()
.when_any(db, |&elem_ty| {
self.has_relation_to_impl(
db,
elem_ty,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
})
// Failing that, if the concrete base type is a union, try delegating to that.
// Otherwise, this would be equivalent to what we just checked, and we
// shouldn't waste time checking it twice.
.or(db, || {
let concrete_base = self_newtype.concrete_base_type(db);
if matches!(concrete_base, Type::Union(_)) {
concrete_base.has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
} else {
ConstraintSet::from(false)
}
})
}
// All other `NewType` assignments fall back to the concrete base type.
(Type::NewTypeInstance(self_newtype), _) => {
self_newtype.concrete_base_type(db).has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
}
(Type::Union(union), _) => union.elements(db).iter().when_all(db, |&elem_ty| {
elem_ty.has_relation_to_impl(
db,
@@ -1305,21 +1378,6 @@ impl<'db> Type<'db> {
})
}
(Type::NewTypeInstance(self_newtype), Type::NewTypeInstance(target_newtype)) => {
self_newtype.has_relation_to_impl(db, target_newtype)
}
(Type::NewTypeInstance(self_newtype), _) => {
self_newtype.concrete_base_type(db).has_relation_to_impl(
db,
target,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
}
(Type::PropertyInstance(_), _) => {
KnownClass::Property.to_instance(db).has_relation_to_impl(
db,