From ab1ac254d9750d23e489331e2a5c6ea08172632b Mon Sep 17 00:00:00 2001 From: Jack O'Connor Date: Tue, 6 Jan 2026 09:32:22 -0800 Subject: [PATCH] [ty] fix comparisons and arithmetic with `NewType`s of `float` (#22105) Fixes https://github.com/astral-sh/ty/issues/2077. --- .../resources/mdtest/annotations/new_types.md | 111 ++++++++++++++++++ .../src/types/infer/builder.rs | 74 +++++++++++- .../ty_python_semantic/src/types/relation.rs | 88 +++++++++++--- 3 files changed, 254 insertions(+), 19 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/new_types.md b/crates/ty_python_semantic/resources/mdtest/annotations/new_types.md index ab55503691..c6b4af5f52 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/new_types.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/new_types.md @@ -168,6 +168,56 @@ on top of that: Foo = NewType("Foo", 42) ``` +## `NewType`s in arithmetic and comparison expressions might or might not act as their base + +These expressions are valid because `Foo` acts as its base type, `int`: + +```py +from typing import NewType + +Foo = NewType("Foo", int) + +reveal_type(Foo(42) + 1) # revealed: int +reveal_type(1 + Foo(42)) # revealed: int +reveal_type(Foo(42) + Foo(42)) # revealed: int +reveal_type(Foo(42) == 42) # revealed: bool +reveal_type(42 == Foo(42)) # revealed: bool +reveal_type(Foo(42) == Foo(42)) # revealed: bool +``` + +However, we can't always substitute `int` for `Foo` to evaluate expressions like these. In the +following cases, only `Foo` itself is valid: + +```py +class Bar: + def __add__(self, other: Foo) -> Foo: + return other + + def __radd__(self, other: Foo) -> Foo: + return other + + def __lt__(self, other: Foo) -> bool: + return True + + def __gt__(self, other: Foo) -> bool: + return True + + def __contains__(self, key: Foo) -> bool: + return True + +reveal_type(Foo(42) + Bar()) # revealed: Foo +reveal_type(Bar() + Foo(42)) # revealed: Foo +reveal_type(Foo(42) < Bar()) # revealed: bool +reveal_type(Bar() < Foo(42)) # revealed: bool +reveal_type(Foo(42) in Bar()) # revealed: bool + +42 + Bar() # error: [unsupported-operator] +Bar() + 42 # error: [unsupported-operator] +42 < Bar() # error: [unsupported-operator] +Bar() < 42 # error: [unsupported-operator] +42 in Bar() # error: [unsupported-operator] +``` + ## `float` and `complex` special cases `float` and `complex` are subject to a special case in the typing spec, which we currently interpret @@ -178,6 +228,7 @@ and we accept the unions they expand into. ```py from typing import NewType +from ty_extensions import static_assert, is_assignable_to Foo = NewType("Foo", float) Foo(3.14) @@ -186,6 +237,15 @@ Foo("hello") # error: [invalid-argument-type] "Argument is incorrect: Expected reveal_type(Foo(3.14).__class__) # revealed: type[int] | type[float] reveal_type(Foo(42).__class__) # revealed: type[int] | type[float] +static_assert(is_assignable_to(Foo, float)) +static_assert(is_assignable_to(Foo, int | float)) +static_assert(is_assignable_to(Foo, int | float | None)) +# The assignments above require treating `Foo` as its underlying union type. Each of its members is +# assignable to the union on the right, so `Foo` is assignable to the union, even though `Foo` as a +# whole isn't assignable to any one member. However, as in the previous section, we need to be sure +# that this treatment doesn't break cases like the assignment below, where `Foo` *is* assignable to +# the union on the right, even though its members *aren't*. +static_assert(is_assignable_to(Foo, Foo | None)) Bar = NewType("Bar", complex) Bar(1 + 2j) @@ -196,6 +256,11 @@ Bar("goodbye") # error: [invalid-argument-type] reveal_type(Bar(1 + 2j).__class__) # revealed: type[int] | type[float] | type[complex] reveal_type(Bar(3.14).__class__) # revealed: type[int] | type[float] | type[complex] reveal_type(Bar(42).__class__) # revealed: type[int] | type[float] | type[complex] +static_assert(is_assignable_to(Bar, complex)) +static_assert(is_assignable_to(Bar, int | float | complex)) +static_assert(is_assignable_to(Bar, int | float | complex | None)) +# See the `Foo | None` case above. +static_assert(is_assignable_to(Bar, Bar | None)) ``` We don't currently try to distinguish between an implicit union (e.g. `float`) and the equivalent @@ -223,6 +288,52 @@ def g(_: Callable[[int | float | complex], Bar]): ... g(Bar) ``` +The arithmetic and comparison test cases in the previous section used a `NewType` of `int`, but +`NewType`s of `float` and `complex` are more complicated, because their base type is a union, and +that union needs special handling in binary expressions. In these examples, we we need to lower +`Foo` to `int | float` and then check each member of that union _individually_, as we would with an +explicit `Union` on the left side: + +```py +reveal_type(Foo(3.14) < Foo(42)) # revealed: bool +reveal_type(Foo(3.14) == Foo(42)) # revealed: bool +reveal_type(Foo(3.14) + Foo(42)) # revealed: int | float +reveal_type(Foo(3.14) / Foo(42)) # revealed: int | float +``` + +But again as above, we can't _always_ lower `Foo` to `int | float`, because there are also binary +expressions where only `Foo` itself is valid: + +```py +class Bing: + def __add__(self, other: Foo) -> Foo: + return other + + def __radd__(self, other: Foo) -> Foo: + return other + + def __lt__(self, other: Foo) -> bool: + return True + + def __gt__(self, other: Foo) -> bool: + return True + + def __contains__(self, key: Foo) -> bool: + return True + +reveal_type(Foo(3.14) + Bing()) # revealed: Foo +reveal_type(Bing() + Foo(42)) # revealed: Foo +reveal_type(Foo(3.14) < Bing()) # revealed: bool +reveal_type(Bing() < Foo(42)) # revealed: bool +reveal_type(Foo(3.14) in Bing()) # revealed: bool + +3.14 + Bing() # error: [unsupported-operator] +Bing() + 3.14 # error: [unsupported-operator] +3.14 < Bing() # error: [unsupported-operator] +Bing() < 3.14 # error: [unsupported-operator] +3.14 in Bing() # error: [unsupported-operator] +``` + ## A `NewType` definition must be a simple variable assignment ```py diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index ee4ce80c9c..be501ec00c 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -10426,6 +10426,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, ), + // `try_call_bin_op` works for almost all `NewType`s, but not for `NewType`s of `float` + // and `complex`, where the concrete base type is a union. In that case it turns out + // the `self` types of the dunder methods in typeshed don't match, because they don't + // get the same `int | float` and `int | float | complex` special treatment that the + // positional arguments get. In those cases we need to explicitly delegate to the base + // type, so that it hits the `Type::Union` branches above. + (Type::NewTypeInstance(newtype), rhs, _) => { + Type::try_call_bin_op(self.db(), left_ty, op, right_ty) + .map(|outcome| outcome.return_type(self.db())) + .ok() + .or_else(|| { + self.infer_binary_expression_type( + node, + emitted_division_by_zero_diagnostic, + newtype.concrete_base_type(self.db()), + rhs, + op, + ) + }) + } + (lhs, Type::NewTypeInstance(newtype), _) => { + Type::try_call_bin_op(self.db(), left_ty, op, right_ty) + .map(|outcome| outcome.return_type(self.db())) + .ok() + .or_else(|| { + self.infer_binary_expression_type( + node, + emitted_division_by_zero_diagnostic, + lhs, + newtype.concrete_base_type(self.db()), + op, + ) + }) + } + // Non-todo Anys take precedence over Todos (as if we fix this `Todo` in the future, // the result would then become Any or Unknown, respectively). (div @ Type::Dynamic(DynamicType::Divergent(_)), _, _) @@ -10762,8 +10797,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | Type::TypeVar(_) | Type::TypeIs(_) | Type::TypeGuard(_) - | Type::TypedDict(_) - | Type::NewTypeInstance(_), + | Type::TypedDict(_), Type::FunctionLiteral(_) | Type::BooleanLiteral(_) | Type::Callable(..) @@ -10793,8 +10827,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | Type::TypeVar(_) | Type::TypeIs(_) | Type::TypeGuard(_) - | Type::TypedDict(_) - | Type::NewTypeInstance(_), + | Type::TypedDict(_), op, ) => Type::try_call_bin_op(self.db(), left_ty, op, right_ty) .map(|outcome| outcome.return_type(self.db())) @@ -11228,6 +11261,39 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) })), + // `try_dunder` works for almost all `NewType`s, but not for `NewType`s of `float` and + // `complex`, where the concrete base type is a union. In that case it turns out the + // `self` types of the dunder methods in typeshed don't match, because they don't get + // the same `int | float` and `int | float | complex` special treatment that the + // positional arguments get. In those cases we need to explicitly delegate to the base + // type, so that it hits the `Type::Union` branches above. + (Type::NewTypeInstance(newtype), right) => Some( + try_dunder(self, MemberLookupPolicy::default()).or_else(|_| { + visitor.visit((left, op, right), || { + self.infer_binary_type_comparison( + newtype.concrete_base_type(self.db()), + op, + right, + range, + visitor, + ) + }) + }), + ), + (left, Type::NewTypeInstance(newtype)) => Some( + try_dunder(self, MemberLookupPolicy::default()).or_else(|_| { + visitor.visit((left, op, right), || { + self.infer_binary_type_comparison( + left, + op, + newtype.concrete_base_type(self.db()), + range, + visitor, + ) + }) + }), + ), + (Type::IntLiteral(n), Type::IntLiteral(m)) => Some(match op { ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)), ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)), diff --git a/crates/ty_python_semantic/src/types/relation.rs b/crates/ty_python_semantic/src/types/relation.rs index 7955dfe61d..3a617fe035 100644 --- a/crates/ty_python_semantic/src/types/relation.rs +++ b/crates/ty_python_semantic/src/types/relation.rs @@ -654,6 +654,79 @@ impl<'db> Type<'db> { // `Never` is the bottom type, the empty set. (_, Type::Never) => ConstraintSet::from(false), + (Type::NewTypeInstance(self_newtype), Type::NewTypeInstance(target_newtype)) => { + self_newtype.has_relation_to_impl(db, target_newtype) + } + // In the special cases of `NewType`s of `float` or `complex`, the concrete base type + // can be a union (`int | float` or `int | float | complex`). For that reason, + // `NewType` assignability to a union needs to consider two different cases. It could + // be that we need to treat the `NewType` as the underlying union it's assignable to, + // for example: + // + // ```py + // Foo = NewType("Foo", float) + // static_assert(is_assignable_to(Foo, float | None)) + // ``` + // + // The right side there is equivalent to `int | float | None`, but `Foo` as a whole + // isn't assignable to any of those three types. However, `Foo`s concrete base type is + // `int | float`, which is assignable, because union members on the left side get + // checked individually. On the other hand, we need to be careful not to break the + // following case, where `int | float` is *not* assignable to the right side: + // + // ```py + // static_assert(is_assignable_to(Foo, Foo | None)) + // ``` + // + // To handle both cases, we have to check that *either* `Foo` as a whole is assignable + // (or subtypeable etc.) *or* that its concrete base type is. Note that this match arm + // needs to take precedence over the `Type::Union` arms immediately below. + (Type::NewTypeInstance(self_newtype), Type::Union(union)) => { + // First the normal "assign to union" case, unfortunately duplicated from below. + union + .elements(db) + .iter() + .when_any(db, |&elem_ty| { + self.has_relation_to_impl( + db, + elem_ty, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + }) + // Failing that, if the concrete base type is a union, try delegating to that. + // Otherwise, this would be equivalent to what we just checked, and we + // shouldn't waste time checking it twice. + .or(db, || { + let concrete_base = self_newtype.concrete_base_type(db); + if matches!(concrete_base, Type::Union(_)) { + concrete_base.has_relation_to_impl( + db, + target, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + } else { + ConstraintSet::from(false) + } + }) + } + // All other `NewType` assignments fall back to the concrete base type. + (Type::NewTypeInstance(self_newtype), _) => { + self_newtype.concrete_base_type(db).has_relation_to_impl( + db, + target, + inferable, + relation, + relation_visitor, + disjointness_visitor, + ) + } + (Type::Union(union), _) => union.elements(db).iter().when_all(db, |&elem_ty| { elem_ty.has_relation_to_impl( db, @@ -1305,21 +1378,6 @@ impl<'db> Type<'db> { }) } - (Type::NewTypeInstance(self_newtype), Type::NewTypeInstance(target_newtype)) => { - self_newtype.has_relation_to_impl(db, target_newtype) - } - - (Type::NewTypeInstance(self_newtype), _) => { - self_newtype.concrete_base_type(db).has_relation_to_impl( - db, - target, - inferable, - relation, - relation_visitor, - disjointness_visitor, - ) - } - (Type::PropertyInstance(_), _) => { KnownClass::Property.to_instance(db).has_relation_to_impl( db,