[red-knot] add special case for float/complex (#16166)

When adjusting the existing tests, I aimed to avoid dealing with the
special case in other tests if it's not necessary to do so (that is,
avoid using `float` and `complex` as examples where we just need "some
type"), and keep the tests for the special case mostly collected in the
mdtest dedicated to that purpose.

Fixes https://github.com/astral-sh/ruff/issues/14932
This commit is contained in:
Carl Meyer 2025-02-14 12:24:10 -08:00 committed by GitHub
parent 219712860c
commit dcabb948f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 236 additions and 108 deletions

View File

@ -0,0 +1,90 @@
# Special cases for int/float/complex in annotations
In order to support common use cases, an annotation of `float` actually means `int | float`, and an
annotation of `complex` actually means `int | float | complex`. See
[the specification](https://typing.readthedocs.io/en/latest/spec/special-types.html#special-cases-for-float-and-complex)
## float
An annotation of `float` means `int | float`, so `int` is assignable to it:
```py
def takes_float(x: float):
pass
def passes_int_to_float(x: int):
# no error!
takes_float(x)
```
It also applies to variable annotations:
```py
def assigns_int_to_float(x: int):
# no error!
y: float = x
```
It doesn't work the other way around:
```py
def takes_int(x: int):
pass
def passes_float_to_int(x: float):
# error: [invalid-argument-type]
takes_int(x)
def assigns_float_to_int(x: float):
# error: [invalid-assignment]
y: int = x
```
Unlike other type checkers, we choose not to obfuscate this special case by displaying `int | float`
as just `float`; we display the actual type:
```py
def f(x: float):
reveal_type(x) # revealed: int | float
```
## complex
An annotation of `complex` means `int | float | complex`, so `int` and `float` are both assignable
to it (but not the other way around):
```py
def takes_complex(x: complex):
pass
def passes_to_complex(x: float, y: int):
# no errors!
takes_complex(x)
takes_complex(y)
def assigns_to_complex(x: float, y: int):
# no errors!
a: complex = x
b: complex = y
def takes_int(x: int):
pass
def takes_float(x: float):
pass
def passes_complex(x: complex):
# error: [invalid-argument-type]
takes_int(x)
# error: [invalid-argument-type]
takes_float(x)
def assigns_complex(x: complex):
# error: [invalid-assignment]
y: int = x
# error: [invalid-assignment]
z: float = x
def f(x: complex):
reveal_type(x) # revealed: int | float | complex
```

View File

@ -9,9 +9,9 @@ from typing import Union
a: Union[int, str] a: Union[int, str]
a1: Union[int, bool] a1: Union[int, bool]
a2: Union[int, Union[float, str]] a2: Union[int, Union[bytes, str]]
a3: Union[int, None] a3: Union[int, None]
a4: Union[Union[float, str]] a4: Union[Union[bytes, str]]
a5: Union[int] a5: Union[int]
a6: Union[()] a6: Union[()]
@ -21,11 +21,11 @@ def f():
# Since bool is a subtype of int we simplify to int here. But we do allow assigning boolean values (see below). # Since bool is a subtype of int we simplify to int here. But we do allow assigning boolean values (see below).
# revealed: int # revealed: int
reveal_type(a1) reveal_type(a1)
# revealed: int | float | str # revealed: int | bytes | str
reveal_type(a2) reveal_type(a2)
# revealed: int | None # revealed: int | None
reveal_type(a3) reveal_type(a3)
# revealed: float | str # revealed: bytes | str
reveal_type(a4) reveal_type(a4)
# revealed: int # revealed: int
reveal_type(a5) reveal_type(a5)

View File

@ -9,7 +9,7 @@ reveal_type(x) # revealed: Literal[2]
x = 1.0 x = 1.0
x /= 2 x /= 2
reveal_type(x) # revealed: float reveal_type(x) # revealed: int | float
``` ```
## Dunder methods ## Dunder methods
@ -24,12 +24,12 @@ x -= 1
reveal_type(x) # revealed: str reveal_type(x) # revealed: str
class C: class C:
def __iadd__(self, other: str) -> float: def __iadd__(self, other: str) -> int:
return 1.0 return 1
x = C() x = C()
x += "Hello" x += "Hello"
reveal_type(x) # revealed: float reveal_type(x) # revealed: int
``` ```
## Unsupported types ## Unsupported types
@ -130,10 +130,10 @@ def _(flag: bool):
if flag: if flag:
f = Foo() f = Foo()
else: else:
f = 42.0 f = 42
f += 12 f += 12
reveal_type(f) # revealed: str | float reveal_type(f) # revealed: str | Literal[54]
``` ```
## Partially bound target union with `__add__` ## Partially bound target union with `__add__`

View File

@ -56,7 +56,7 @@ def _(a: bool):
reveal_type(x - a) # revealed: int reveal_type(x - a) # revealed: int
reveal_type(x * a) # revealed: int reveal_type(x * a) # revealed: int
reveal_type(x // a) # revealed: int reveal_type(x // a) # revealed: int
reveal_type(x / a) # revealed: float reveal_type(x / a) # revealed: int | float
reveal_type(x % a) # revealed: int reveal_type(x % a) # revealed: int
def rhs_is_int(x: int): def rhs_is_int(x: int):
@ -64,7 +64,7 @@ def _(a: bool):
reveal_type(a - x) # revealed: int reveal_type(a - x) # revealed: int
reveal_type(a * x) # revealed: int reveal_type(a * x) # revealed: int
reveal_type(a // x) # revealed: int reveal_type(a // x) # revealed: int
reveal_type(a / x) # revealed: float reveal_type(a / x) # revealed: int | float
reveal_type(a % x) # revealed: int reveal_type(a % x) # revealed: int
def lhs_is_bool(x: bool): def lhs_is_bool(x: bool):
@ -72,7 +72,7 @@ def _(a: bool):
reveal_type(x - a) # revealed: int reveal_type(x - a) # revealed: int
reveal_type(x * a) # revealed: int reveal_type(x * a) # revealed: int
reveal_type(x // a) # revealed: int reveal_type(x // a) # revealed: int
reveal_type(x / a) # revealed: float reveal_type(x / a) # revealed: int | float
reveal_type(x % a) # revealed: int reveal_type(x % a) # revealed: int
def rhs_is_bool(x: bool): def rhs_is_bool(x: bool):
@ -80,7 +80,7 @@ def _(a: bool):
reveal_type(a - x) # revealed: int reveal_type(a - x) # revealed: int
reveal_type(a * x) # revealed: int reveal_type(a * x) # revealed: int
reveal_type(a // x) # revealed: int reveal_type(a // x) # revealed: int
reveal_type(a / x) # revealed: float reveal_type(a / x) # revealed: int | float
reveal_type(a % x) # revealed: int reveal_type(a % x) # revealed: int
def both_are_bool(x: bool, y: bool): def both_are_bool(x: bool, y: bool):
@ -88,6 +88,6 @@ def _(a: bool):
reveal_type(x - y) # revealed: int reveal_type(x - y) # revealed: int
reveal_type(x * y) # revealed: int reveal_type(x * y) # revealed: int
reveal_type(x // y) # revealed: int reveal_type(x // y) # revealed: int
reveal_type(x / y) # revealed: float reveal_type(x / y) # revealed: int | float
reveal_type(x % y) # revealed: int reveal_type(x % y) # revealed: int
``` ```

View File

@ -268,23 +268,28 @@ reveal_type(B() + B()) # revealed: Unknown | int
## Integration test: numbers from typeshed ## Integration test: numbers from typeshed
We get less precise results from binary operations on float/complex literals due to the special case
for annotations of `float` or `complex`, which applies also to return annotations for typeshed
dunder methods. Perhaps we could have a special-case on the special-case, to exclude these typeshed
return annotations from the widening, and preserve a bit more precision here?
```py ```py
reveal_type(3j + 3.14) # revealed: complex reveal_type(3j + 3.14) # revealed: int | float | complex
reveal_type(4.2 + 42) # revealed: float reveal_type(4.2 + 42) # revealed: int | float
reveal_type(3j + 3) # revealed: complex reveal_type(3j + 3) # revealed: int | float | complex
# TODO should be complex, need to check arg type and fall back to `rhs.__radd__` # TODO should be int | float | complex, need to check arg type and fall back to `rhs.__radd__`
reveal_type(3.14 + 3j) # revealed: float reveal_type(3.14 + 3j) # revealed: int | float
# TODO should be float, need to check arg type and fall back to `rhs.__radd__` # TODO should be int | float, need to check arg type and fall back to `rhs.__radd__`
reveal_type(42 + 4.2) # revealed: int reveal_type(42 + 4.2) # revealed: int
# TODO should be complex, need to check arg type and fall back to `rhs.__radd__` # TODO should be int | float | complex, need to check arg type and fall back to `rhs.__radd__`
reveal_type(3 + 3j) # revealed: int reveal_type(3 + 3j) # revealed: int
def _(x: bool, y: int): def _(x: bool, y: int):
reveal_type(x + y) # revealed: int reveal_type(x + y) # revealed: int
reveal_type(4.2 + x) # revealed: float reveal_type(4.2 + x) # revealed: int | float
# TODO should be float, need to check arg type and fall back to `rhs.__radd__` # TODO should be float, need to check arg type and fall back to `rhs.__radd__`
reveal_type(y + 4.12) # revealed: int reveal_type(y + 4.12) # revealed: int

View File

@ -19,7 +19,7 @@ def lhs(x: int):
reveal_type(x - 4) # revealed: int reveal_type(x - 4) # revealed: int
reveal_type(x * -1) # revealed: int reveal_type(x * -1) # revealed: int
reveal_type(x // 3) # revealed: int reveal_type(x // 3) # revealed: int
reveal_type(x / 3) # revealed: float reveal_type(x / 3) # revealed: int | float
reveal_type(x % 3) # revealed: int reveal_type(x % 3) # revealed: int
def rhs(x: int): def rhs(x: int):
@ -27,7 +27,7 @@ def rhs(x: int):
reveal_type(3 - x) # revealed: int reveal_type(3 - x) # revealed: int
reveal_type(3 * x) # revealed: int reveal_type(3 * x) # revealed: int
reveal_type(-3 // x) # revealed: int reveal_type(-3 // x) # revealed: int
reveal_type(-3 / x) # revealed: float reveal_type(-3 / x) # revealed: int | float
reveal_type(5 % x) # revealed: int reveal_type(5 % x) # revealed: int
def both(x: int): def both(x: int):
@ -35,7 +35,7 @@ def both(x: int):
reveal_type(x - x) # revealed: int reveal_type(x - x) # revealed: int
reveal_type(x * x) # revealed: int reveal_type(x * x) # revealed: int
reveal_type(x // x) # revealed: int reveal_type(x // x) # revealed: int
reveal_type(x / x) # revealed: float reveal_type(x / x) # revealed: int | float
reveal_type(x % x) # revealed: int reveal_type(x % x) # revealed: int
``` ```
@ -80,24 +80,20 @@ c = 3 % 0 # error: "Cannot reduce object of type `Literal[3]` modulo zero"
reveal_type(c) # revealed: int reveal_type(c) # revealed: int
# error: "Cannot divide object of type `int` by zero" # error: "Cannot divide object of type `int` by zero"
# revealed: float reveal_type(int() / 0) # revealed: int | float
reveal_type(int() / 0)
# error: "Cannot divide object of type `Literal[1]` by zero" # error: "Cannot divide object of type `Literal[1]` by zero"
# revealed: float reveal_type(1 / False) # revealed: float
reveal_type(1 / False)
# error: [division-by-zero] "Cannot divide object of type `Literal[True]` by zero" # error: [division-by-zero] "Cannot divide object of type `Literal[True]` by zero"
True / False True / False
# error: [division-by-zero] "Cannot divide object of type `Literal[True]` by zero" # error: [division-by-zero] "Cannot divide object of type `Literal[True]` by zero"
bool(1) / False bool(1) / False
# error: "Cannot divide object of type `float` by zero" # error: "Cannot divide object of type `float` by zero"
# revealed: float reveal_type(1.0 / 0) # revealed: int | float
reveal_type(1.0 / 0)
class MyInt(int): ... class MyInt(int): ...
# No error for a subclass of int # No error for a subclass of int
# revealed: float reveal_type(MyInt(3) / 0) # revealed: int | float
reveal_type(MyInt(3) / 0)
``` ```

View File

@ -4,14 +4,14 @@
```py ```py
class Multiplier: class Multiplier:
def __init__(self, factor: float): def __init__(self, factor: int):
self.factor = factor self.factor = factor
def __call__(self, number: float) -> float: def __call__(self, number: int) -> int:
return number * self.factor return number * self.factor
a = Multiplier(2.0)(3.0) a = Multiplier(2)(3)
reveal_type(a) # revealed: float reveal_type(a) # revealed: int
class Unit: ... class Unit: ...

View File

@ -20,8 +20,8 @@ class A:
def __eq__(self, other: A) -> int: def __eq__(self, other: A) -> int:
return 42 return 42
def __ne__(self, other: A) -> float: def __ne__(self, other: A) -> bytearray:
return 42.0 return bytearray()
def __lt__(self, other: A) -> str: def __lt__(self, other: A) -> str:
return "42" return "42"
@ -36,7 +36,7 @@ class A:
return {42} return {42}
reveal_type(A() == A()) # revealed: int reveal_type(A() == A()) # revealed: int
reveal_type(A() != A()) # revealed: float reveal_type(A() != A()) # revealed: bytearray
reveal_type(A() < A()) # revealed: str reveal_type(A() < A()) # revealed: str
reveal_type(A() <= A()) # revealed: bytes reveal_type(A() <= A()) # revealed: bytes
reveal_type(A() > A()) # revealed: list reveal_type(A() > A()) # revealed: list
@ -55,8 +55,8 @@ class A:
def __eq__(self, other: B) -> int: def __eq__(self, other: B) -> int:
return 42 return 42
def __ne__(self, other: B) -> float: def __ne__(self, other: B) -> bytearray:
return 42.0 return bytearray()
def __lt__(self, other: B) -> str: def __lt__(self, other: B) -> str:
return "42" return "42"
@ -73,7 +73,7 @@ class A:
class B: ... class B: ...
reveal_type(A() == B()) # revealed: int reveal_type(A() == B()) # revealed: int
reveal_type(A() != B()) # revealed: float reveal_type(A() != B()) # revealed: bytearray
reveal_type(A() < B()) # revealed: str reveal_type(A() < B()) # revealed: str
reveal_type(A() <= B()) # revealed: bytes reveal_type(A() <= B()) # revealed: bytes
reveal_type(A() > B()) # revealed: list reveal_type(A() > B()) # revealed: list
@ -93,8 +93,8 @@ class A:
def __eq__(self, other: B) -> int: def __eq__(self, other: B) -> int:
return 42 return 42
def __ne__(self, other: B) -> float: def __ne__(self, other: B) -> bytearray:
return 42.0 return bytearray()
def __lt__(self, other: B) -> str: def __lt__(self, other: B) -> str:
return "42" return "42"
@ -117,7 +117,7 @@ class B:
def __ne__(self, other: str) -> B: def __ne__(self, other: str) -> B:
return B() return B()
# TODO: should be `int` and `float`. # TODO: should be `int` and `bytearray`.
# Need to check arg type and fall back to `rhs.__eq__` and `rhs.__ne__`. # Need to check arg type and fall back to `rhs.__eq__` and `rhs.__ne__`.
# #
# Because `object.__eq__` and `object.__ne__` accept `object` in typeshed, # Because `object.__eq__` and `object.__ne__` accept `object` in typeshed,
@ -136,11 +136,11 @@ class C:
def __gt__(self, other: C) -> int: def __gt__(self, other: C) -> int:
return 42 return 42
def __ge__(self, other: C) -> float: def __ge__(self, other: C) -> bytearray:
return 42.0 return bytearray()
reveal_type(C() < C()) # revealed: int reveal_type(C() < C()) # revealed: int
reveal_type(C() <= C()) # revealed: float reveal_type(C() <= C()) # revealed: bytearray
``` ```
## Reflected Comparisons with Subclasses ## Reflected Comparisons with Subclasses
@ -175,8 +175,8 @@ class B(A):
def __eq__(self, other: A) -> int: def __eq__(self, other: A) -> int:
return 42 return 42
def __ne__(self, other: A) -> float: def __ne__(self, other: A) -> bytearray:
return 42.0 return bytearray()
def __lt__(self, other: A) -> str: def __lt__(self, other: A) -> str:
return "42" return "42"
@ -191,7 +191,7 @@ class B(A):
return {42} return {42}
reveal_type(A() == B()) # revealed: int reveal_type(A() == B()) # revealed: int
reveal_type(A() != B()) # revealed: float reveal_type(A() != B()) # revealed: bytearray
reveal_type(A() < B()) # revealed: list reveal_type(A() < B()) # revealed: list
reveal_type(A() <= B()) # revealed: set reveal_type(A() <= B()) # revealed: set

View File

@ -151,11 +151,11 @@ class A:
def __ne__(self, o: object) -> bytes: def __ne__(self, o: object) -> bytes:
return b"world" return b"world"
def __lt__(self, o: A) -> float: def __lt__(self, o: A) -> bytearray:
return 3.14 return bytearray()
def __le__(self, o: A) -> complex: def __le__(self, o: A) -> memoryview:
return complex(0.5, -0.5) return memoryview(b"")
def __gt__(self, o: A) -> tuple: def __gt__(self, o: A) -> tuple:
return (1, 2, 3) return (1, 2, 3)
@ -167,8 +167,8 @@ a = (A(), A())
reveal_type(a == a) # revealed: bool reveal_type(a == a) # revealed: bool
reveal_type(a != a) # revealed: bool reveal_type(a != a) # revealed: bool
reveal_type(a < a) # revealed: float | Literal[False] reveal_type(a < a) # revealed: bytearray | Literal[False]
reveal_type(a <= a) # revealed: complex | Literal[True] reveal_type(a <= a) # revealed: memoryview | Literal[True]
reveal_type(a > a) # revealed: tuple | Literal[False] reveal_type(a > a) # revealed: tuple | Literal[False]
reveal_type(a >= a) # revealed: list | Literal[True] reveal_type(a >= a) # revealed: list | Literal[True]
@ -187,7 +187,7 @@ class B:
def __lt__(self, o: B) -> set: def __lt__(self, o: B) -> set:
return set() return set()
reveal_type((A(), B()) < (A(), B())) # revealed: float | set | Literal[False] reveal_type((A(), B()) < (A(), B())) # revealed: bytearray | set | Literal[False]
``` ```
#### Special Handling of Eq and NotEq in Lexicographic Comparisons #### Special Handling of Eq and NotEq in Lexicographic Comparisons

View File

@ -303,8 +303,8 @@ An example with multiple `except` branches and a `finally` branch:
def could_raise_returns_memoryview() -> memoryview: def could_raise_returns_memoryview() -> memoryview:
return memoryview(b"") return memoryview(b"")
def could_raise_returns_float() -> float: def could_raise_returns_bytearray() -> bytearray:
return 3.14 return bytearray()
x = 1 x = 1
@ -322,13 +322,13 @@ except ValueError:
reveal_type(x) # revealed: Literal[1] | str reveal_type(x) # revealed: Literal[1] | str
x = could_raise_returns_memoryview() x = could_raise_returns_memoryview()
reveal_type(x) # revealed: memoryview reveal_type(x) # revealed: memoryview
x = could_raise_returns_float() x = could_raise_returns_bytearray()
reveal_type(x) # revealed: float reveal_type(x) # revealed: bytearray
finally: finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float` # TODO: should be `Literal[1] | str | bytes | bool | memoryview | bytearray`
reveal_type(x) # revealed: str | bool | float reveal_type(x) # revealed: str | bool | bytearray
reveal_type(x) # revealed: str | bool | float reveal_type(x) # revealed: str | bool | bytearray
``` ```
## Combining `except`, `else` and `finally` branches ## Combining `except`, `else` and `finally` branches
@ -350,8 +350,8 @@ def could_raise_returns_bool() -> bool:
def could_raise_returns_memoryview() -> memoryview: def could_raise_returns_memoryview() -> memoryview:
return memoryview(b"") return memoryview(b"")
def could_raise_returns_float() -> float: def could_raise_returns_bytearray() -> bytearray:
return 3.14 return bytearray()
x = 1 x = 1
@ -369,13 +369,13 @@ else:
reveal_type(x) # revealed: str reveal_type(x) # revealed: str
x = could_raise_returns_memoryview() x = could_raise_returns_memoryview()
reveal_type(x) # revealed: memoryview reveal_type(x) # revealed: memoryview
x = could_raise_returns_float() x = could_raise_returns_bytearray()
reveal_type(x) # revealed: float reveal_type(x) # revealed: bytearray
finally: finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float` # TODO: should be `Literal[1] | str | bytes | bool | memoryview | bytearray`
reveal_type(x) # revealed: bool | float reveal_type(x) # revealed: bool | bytearray
reveal_type(x) # revealed: bool | float reveal_type(x) # revealed: bool | bytearray
``` ```
The same again, this time with multiple `except` branches: The same again, this time with multiple `except` branches:
@ -403,8 +403,8 @@ except ValueError:
reveal_type(x) # revealed: Literal[1] | str reveal_type(x) # revealed: Literal[1] | str
x = could_raise_returns_memoryview() x = could_raise_returns_memoryview()
reveal_type(x) # revealed: memoryview reveal_type(x) # revealed: memoryview
x = could_raise_returns_float() x = could_raise_returns_bytearray()
reveal_type(x) # revealed: float reveal_type(x) # revealed: bytearray
else: else:
reveal_type(x) # revealed: str reveal_type(x) # revealed: str
x = could_raise_returns_range() x = could_raise_returns_range()
@ -412,10 +412,10 @@ else:
x = could_raise_returns_slice() x = could_raise_returns_slice()
reveal_type(x) # revealed: slice reveal_type(x) # revealed: slice
finally: finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` # TODO: should be `Literal[1] | str | bytes | bool | memoryview | bytearray | range | slice`
reveal_type(x) # revealed: bool | float | slice reveal_type(x) # revealed: bool | bytearray | slice
reveal_type(x) # revealed: bool | float | slice reveal_type(x) # revealed: bool | bytearray | slice
``` ```
## Nested `try`/`except` blocks ## Nested `try`/`except` blocks
@ -441,8 +441,8 @@ def could_raise_returns_bool() -> bool:
def could_raise_returns_memoryview() -> memoryview: def could_raise_returns_memoryview() -> memoryview:
return memoryview(b"") return memoryview(b"")
def could_raise_returns_float() -> float: def could_raise_returns_property() -> property:
return 3.14 return property()
def could_raise_returns_range() -> range: def could_raise_returns_range() -> range:
return range(42) return range(42)
@ -450,8 +450,8 @@ def could_raise_returns_range() -> range:
def could_raise_returns_slice() -> slice: def could_raise_returns_slice() -> slice:
return slice(None) return slice(None)
def could_raise_returns_complex() -> complex: def could_raise_returns_super() -> super:
return 3j return super()
def could_raise_returns_bytearray() -> bytearray: def could_raise_returns_bytearray() -> bytearray:
return bytearray() return bytearray()
@ -482,8 +482,8 @@ try:
reveal_type(x) # revealed: Literal[1] | str reveal_type(x) # revealed: Literal[1] | str
x = could_raise_returns_memoryview() x = could_raise_returns_memoryview()
reveal_type(x) # revealed: memoryview reveal_type(x) # revealed: memoryview
x = could_raise_returns_float() x = could_raise_returns_property()
reveal_type(x) # revealed: float reveal_type(x) # revealed: property
else: else:
reveal_type(x) # revealed: str reveal_type(x) # revealed: str
x = could_raise_returns_range() x = could_raise_returns_range()
@ -491,15 +491,15 @@ try:
x = could_raise_returns_slice() x = could_raise_returns_slice()
reveal_type(x) # revealed: slice reveal_type(x) # revealed: slice
finally: finally:
# TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` # TODO: should be `Literal[1] | str | bytes | bool | memoryview | property | range | slice`
reveal_type(x) # revealed: bool | float | slice reveal_type(x) # revealed: bool | property | slice
x = 2 x = 2
reveal_type(x) # revealed: Literal[2] reveal_type(x) # revealed: Literal[2]
reveal_type(x) # revealed: Literal[2] reveal_type(x) # revealed: Literal[2]
except: except:
reveal_type(x) # revealed: Literal[1, 2] | str | bytes | bool | memoryview | float | range | slice reveal_type(x) # revealed: Literal[1, 2] | str | bytes | bool | memoryview | property | range | slice
x = could_raise_returns_complex() x = could_raise_returns_super()
reveal_type(x) # revealed: complex reveal_type(x) # revealed: super
x = could_raise_returns_bytearray() x = could_raise_returns_bytearray()
reveal_type(x) # revealed: bytearray reveal_type(x) # revealed: bytearray
else: else:
@ -509,7 +509,7 @@ else:
x = could_raise_returns_Bar() x = could_raise_returns_Bar()
reveal_type(x) # revealed: Bar reveal_type(x) # revealed: Bar
finally: finally:
# TODO: should be `Literal[1, 2] | str | bytes | bool | memoryview | float | range | slice | complex | bytearray | Foo | Bar` # TODO: should be `Literal[1, 2] | str | bytes | bool | memoryview | property | range | slice | super | bytearray | Foo | Bar`
reveal_type(x) # revealed: bytearray | Bar reveal_type(x) # revealed: bytearray | Bar
# Either one `except` branch or the `else` # Either one `except` branch or the `else`
@ -535,8 +535,8 @@ def could_raise_returns_range() -> range:
def could_raise_returns_bytearray() -> bytearray: def could_raise_returns_bytearray() -> bytearray:
return bytearray() return bytearray()
def could_raise_returns_float() -> float: def could_raise_returns_memoryview() -> memoryview:
return 3.14 return memoryview(b"")
x = 1 x = 1
@ -553,12 +553,12 @@ try:
reveal_type(x) # revealed: str | bytes reveal_type(x) # revealed: str | bytes
x = could_raise_returns_bytearray() x = could_raise_returns_bytearray()
reveal_type(x) # revealed: bytearray reveal_type(x) # revealed: bytearray
x = could_raise_returns_float() x = could_raise_returns_memoryview()
reveal_type(x) # revealed: float reveal_type(x) # revealed: memoryview
finally: finally:
# TODO: should be `str | bytes | bytearray | float` # TODO: should be `str | bytes | bytearray | memoryview`
reveal_type(x) # revealed: bytes | float reveal_type(x) # revealed: bytes | memoryview
reveal_type(x) # revealed: bytes | float reveal_type(x) # revealed: bytes | memoryview
x = foo x = foo
reveal_type(x) # revealed: Literal[foo] reveal_type(x) # revealed: Literal[foo]
except: except:

View File

@ -11,11 +11,15 @@ See the [typing documentation] for more information.
- `bool` is a subtype of `int`. This is modeled after Python's runtime behavior, where `int` is a - `bool` is a subtype of `int`. This is modeled after Python's runtime behavior, where `int` is a
supertype of `bool` (present in `bool`s bases and MRO). supertype of `bool` (present in `bool`s bases and MRO).
- `int` is not a subtype of `float`/`complex`, even though `float`/`complex` can be used in place of - `int` is not a subtype of `float`/`complex`, although this is muddied by the
`int` in some contexts (see [special case for float and complex]). [special case for float and complex] where annotations of `float` and `complex` are interpreted
as `int | float` and `int | float | complex`, respectively.
```py ```py
from knot_extensions import is_subtype_of, static_assert from knot_extensions import is_subtype_of, static_assert, TypeOf
type JustFloat = TypeOf[1.0]
type JustComplex = TypeOf[1j]
static_assert(is_subtype_of(bool, bool)) static_assert(is_subtype_of(bool, bool))
static_assert(is_subtype_of(bool, int)) static_assert(is_subtype_of(bool, int))
@ -30,8 +34,8 @@ static_assert(not is_subtype_of(int, bool))
static_assert(not is_subtype_of(int, str)) static_assert(not is_subtype_of(int, str))
static_assert(not is_subtype_of(object, int)) static_assert(not is_subtype_of(object, int))
static_assert(not is_subtype_of(int, float)) static_assert(not is_subtype_of(int, JustFloat))
static_assert(not is_subtype_of(int, complex)) static_assert(not is_subtype_of(int, JustComplex))
static_assert(is_subtype_of(TypeError, Exception)) static_assert(is_subtype_of(TypeError, Exception))
static_assert(is_subtype_of(FloatingPointError, Exception)) static_assert(is_subtype_of(FloatingPointError, Exception))
@ -79,7 +83,9 @@ static_assert(is_subtype_of(C, object))
```py ```py
from typing_extensions import Literal, LiteralString from typing_extensions import Literal, LiteralString
from knot_extensions import is_subtype_of, static_assert from knot_extensions import is_subtype_of, static_assert, TypeOf
type JustFloat = TypeOf[1.0]
# Boolean literals # Boolean literals
static_assert(is_subtype_of(Literal[True], bool)) static_assert(is_subtype_of(Literal[True], bool))
@ -92,8 +98,7 @@ static_assert(is_subtype_of(Literal[1], object))
static_assert(not is_subtype_of(Literal[1], bool)) static_assert(not is_subtype_of(Literal[1], bool))
# See the note above (or link below) concerning int and float/complex static_assert(not is_subtype_of(Literal[1], JustFloat))
static_assert(not is_subtype_of(Literal[1], float))
# String literals # String literals
static_assert(is_subtype_of(Literal["foo"], LiteralString)) static_assert(is_subtype_of(Literal["foo"], LiteralString))

View File

@ -70,11 +70,11 @@ from typing import Literal
def _( def _(
u1: (int | str) | bytes, u1: (int | str) | bytes,
u2: int | (str | bytes), u2: int | (str | bytes),
u3: int | (str | (bytes | complex)), u3: int | (str | (bytes | bytearray)),
) -> None: ) -> None:
reveal_type(u1) # revealed: int | str | bytes reveal_type(u1) # revealed: int | str | bytes
reveal_type(u2) # revealed: int | str | bytes reveal_type(u2) # revealed: int | str | bytes
reveal_type(u3) # revealed: int | str | bytes | complex reveal_type(u3) # revealed: int | str | bytes | bytearray
``` ```
## Simplification using subtyping ## Simplification using subtyping

View File

@ -1765,6 +1765,7 @@ impl<'db> Type<'db> {
| KnownClass::Type | KnownClass::Type
| KnownClass::Int | KnownClass::Int
| KnownClass::Float | KnownClass::Float
| KnownClass::Complex
| KnownClass::Str | KnownClass::Str
| KnownClass::List | KnownClass::List
| KnownClass::Tuple | KnownClass::Tuple
@ -2433,6 +2434,31 @@ impl<'db> Type<'db> {
db: &'db dyn Db, db: &'db dyn Db,
) -> Result<Type<'db>, InvalidTypeExpressionError<'db>> { ) -> Result<Type<'db>, InvalidTypeExpressionError<'db>> {
match self { match self {
// Special cases for `float` and `complex`
// https://typing.readthedocs.io/en/latest/spec/special-types.html#special-cases-for-float-and-complex
Type::ClassLiteral(ClassLiteralType { class })
if class.is_known(db, KnownClass::Float) =>
{
Ok(UnionType::from_elements(
db,
[
KnownClass::Int.to_instance(db),
KnownClass::Float.to_instance(db),
],
))
}
Type::ClassLiteral(ClassLiteralType { class })
if class.is_known(db, KnownClass::Complex) =>
{
Ok(UnionType::from_elements(
db,
[
KnownClass::Int.to_instance(db),
KnownClass::Float.to_instance(db),
KnownClass::Complex.to_instance(db),
],
))
}
// In a type expression, a bare `type` is interpreted as "instance of `type`", which is // In a type expression, a bare `type` is interpreted as "instance of `type`", which is
// equivalent to `type[object]`. // equivalent to `type[object]`.
Type::ClassLiteral(_) | Type::SubclassOf(_) => Ok(self.to_instance(db)), Type::ClassLiteral(_) | Type::SubclassOf(_) => Ok(self.to_instance(db)),
@ -2808,6 +2834,7 @@ pub enum KnownClass {
Type, Type,
Int, Int,
Float, Float,
Complex,
Str, Str,
List, List,
Tuple, Tuple,
@ -2853,6 +2880,7 @@ impl<'db> KnownClass {
Self::Tuple => "tuple", Self::Tuple => "tuple",
Self::Int => "int", Self::Int => "int",
Self::Float => "float", Self::Float => "float",
Self::Complex => "complex",
Self::FrozenSet => "frozenset", Self::FrozenSet => "frozenset",
Self::Str => "str", Self::Str => "str",
Self::Set => "set", Self::Set => "set",
@ -2922,6 +2950,7 @@ impl<'db> KnownClass {
| Self::Type | Self::Type
| Self::Int | Self::Int
| Self::Float | Self::Float
| Self::Complex
| Self::Str | Self::Str
| Self::List | Self::List
| Self::Tuple | Self::Tuple
@ -2971,6 +3000,7 @@ impl<'db> KnownClass {
| Self::Tuple | Self::Tuple
| Self::Int | Self::Int
| Self::Float | Self::Float
| Self::Complex
| Self::Str | Self::Str
| Self::Set | Self::Set
| Self::FrozenSet | Self::FrozenSet
@ -3007,6 +3037,7 @@ impl<'db> KnownClass {
"type" => Self::Type, "type" => Self::Type,
"int" => Self::Int, "int" => Self::Int,
"float" => Self::Float, "float" => Self::Float,
"complex" => Self::Complex,
"str" => Self::Str, "str" => Self::Str,
"set" => Self::Set, "set" => Self::Set,
"frozenset" => Self::FrozenSet, "frozenset" => Self::FrozenSet,
@ -3046,6 +3077,7 @@ impl<'db> KnownClass {
| Self::Type | Self::Type
| Self::Int | Self::Int
| Self::Float | Self::Float
| Self::Complex
| Self::Str | Self::Str
| Self::List | Self::List
| Self::Tuple | Self::Tuple