mirror of https://github.com/astral-sh/ruff
[ty] Improve literal promotion heuristics (#21439)
## Summary Extends literal promotion to apply to any generic method, as opposed to only generic class constructors. This PR also improves our literal promotion heuristics to only promote literals in non-covariant position in the return type, and avoid promotion if the literal is present in non-covariant position in any argument type. Resolves https://github.com/astral-sh/ty/issues/1357.
This commit is contained in:
parent
3e7e91724c
commit
c5d654bce8
|
|
@ -229,87 +229,6 @@ a: list[str] = [1, 2, 3]
|
|||
b: set[int] = {1, 2, "3"}
|
||||
```
|
||||
|
||||
## Literal annnotations are respected
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
```py
|
||||
from enum import Enum
|
||||
from typing_extensions import Literal, LiteralString
|
||||
|
||||
a: list[Literal[1]] = [1]
|
||||
reveal_type(a) # revealed: list[Literal[1]]
|
||||
|
||||
b: list[Literal[True]] = [True]
|
||||
reveal_type(b) # revealed: list[Literal[True]]
|
||||
|
||||
c: list[Literal["a"]] = ["a"]
|
||||
reveal_type(c) # revealed: list[Literal["a"]]
|
||||
|
||||
d: list[LiteralString] = ["a", "b", "c"]
|
||||
reveal_type(d) # revealed: list[LiteralString]
|
||||
|
||||
e: list[list[Literal[1]]] = [[1]]
|
||||
reveal_type(e) # revealed: list[list[Literal[1]]]
|
||||
|
||||
class Color(Enum):
|
||||
RED = "red"
|
||||
|
||||
f: dict[list[Literal[1]], list[Literal[Color.RED]]] = {[1]: [Color.RED, Color.RED]}
|
||||
reveal_type(f) # revealed: dict[list[Literal[1]], list[Color]]
|
||||
|
||||
class X[T]:
|
||||
def __init__(self, value: T): ...
|
||||
|
||||
g: X[Literal[1]] = X(1)
|
||||
reveal_type(g) # revealed: X[Literal[1]]
|
||||
|
||||
h: X[int] = X(1)
|
||||
reveal_type(h) # revealed: X[int]
|
||||
|
||||
i: dict[list[X[Literal[1]]], set[Literal[b"a"]]] = {[X(1)]: {b"a"}}
|
||||
reveal_type(i) # revealed: dict[list[X[Literal[1]]], set[Literal[b"a"]]]
|
||||
|
||||
j: list[Literal[1, 2, 3]] = [1, 2, 3]
|
||||
reveal_type(j) # revealed: list[Literal[1, 2, 3]]
|
||||
|
||||
k: list[Literal[1] | Literal[2] | Literal[3]] = [1, 2, 3]
|
||||
reveal_type(k) # revealed: list[Literal[1, 2, 3]]
|
||||
|
||||
type Y[T] = list[T]
|
||||
|
||||
l: Y[Y[Literal[1]]] = [[1]]
|
||||
reveal_type(l) # revealed: list[Y[Literal[1]]]
|
||||
|
||||
m: list[tuple[Literal[1], Literal[2], Literal[3]]] = [(1, 2, 3)]
|
||||
reveal_type(m) # revealed: list[tuple[Literal[1], Literal[2], Literal[3]]]
|
||||
|
||||
n: list[tuple[int, str, int]] = [(1, "2", 3), (4, "5", 6)]
|
||||
reveal_type(n) # revealed: list[tuple[int, str, int]]
|
||||
|
||||
o: list[tuple[Literal[1], ...]] = [(1, 1, 1)]
|
||||
reveal_type(o) # revealed: list[tuple[Literal[1], ...]]
|
||||
|
||||
p: list[tuple[int, ...]] = [(1, 1, 1)]
|
||||
reveal_type(p) # revealed: list[tuple[int, ...]]
|
||||
|
||||
# literal promotion occurs based on assignability, an exact match is not required
|
||||
q: list[int | Literal[1]] = [1]
|
||||
reveal_type(q) # revealed: list[int]
|
||||
|
||||
r: list[Literal[1, 2, 3, 4]] = [1, 2]
|
||||
reveal_type(r) # revealed: list[Literal[1, 2, 3, 4]]
|
||||
|
||||
s: list[Literal[1]]
|
||||
s = [1]
|
||||
reveal_type(s) # revealed: list[Literal[1]]
|
||||
(s := [1])
|
||||
reveal_type(s) # revealed: list[Literal[1]]
|
||||
```
|
||||
|
||||
## Generic constructor annotations are understood
|
||||
|
||||
```toml
|
||||
|
|
@ -352,17 +271,25 @@ from dataclasses import dataclass
|
|||
class Y[T]:
|
||||
value: T
|
||||
|
||||
y1: Y[Any] = Y(value=1)
|
||||
reveal_type(y1) # revealed: Y[Any]
|
||||
y1 = Y(value=1)
|
||||
reveal_type(y1) # revealed: Y[int]
|
||||
|
||||
y2: Y[Any] = Y(value=1)
|
||||
reveal_type(y2) # revealed: Y[Any]
|
||||
```
|
||||
|
||||
```py
|
||||
class Z[T]:
|
||||
value: T
|
||||
|
||||
def __new__(cls, value: T):
|
||||
return super().__new__(cls)
|
||||
|
||||
z1: Z[Any] = Z(1)
|
||||
reveal_type(z1) # revealed: Z[Any]
|
||||
z1 = Z(1)
|
||||
reveal_type(z1) # revealed: Z[int]
|
||||
|
||||
z2: Z[Any] = Z(1)
|
||||
reveal_type(z2) # revealed: Z[Any]
|
||||
```
|
||||
|
||||
## PEP-604 annotations are supported
|
||||
|
|
@ -481,7 +408,7 @@ def f[T](x: T) -> list[T]:
|
|||
return [x]
|
||||
|
||||
a = f("a")
|
||||
reveal_type(a) # revealed: list[Literal["a"]]
|
||||
reveal_type(a) # revealed: list[str]
|
||||
|
||||
b: list[int | Literal["a"]] = f("a")
|
||||
reveal_type(b) # revealed: list[int | Literal["a"]]
|
||||
|
|
@ -495,10 +422,10 @@ reveal_type(d) # revealed: list[int | tuple[int, int]]
|
|||
e: list[int] = f(True)
|
||||
reveal_type(e) # revealed: list[int]
|
||||
|
||||
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `list[int]`"
|
||||
# error: [invalid-assignment] "Object of type `list[str]` is not assignable to `list[int]`"
|
||||
g: list[int] = f("a")
|
||||
|
||||
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
|
||||
# error: [invalid-assignment] "Object of type `list[str]` is not assignable to `tuple[int]`"
|
||||
h: tuple[int] = f("a")
|
||||
|
||||
def f2[T: int](x: T) -> T:
|
||||
|
|
@ -603,7 +530,7 @@ def f3[T](x: T) -> list[T] | dict[T, T]:
|
|||
return [x]
|
||||
|
||||
a = f(1)
|
||||
reveal_type(a) # revealed: list[Literal[1]]
|
||||
reveal_type(a) # revealed: list[int]
|
||||
|
||||
b: list[Any] = f(1)
|
||||
reveal_type(b) # revealed: list[Any]
|
||||
|
|
@ -619,11 +546,11 @@ reveal_type(e) # revealed: list[Any]
|
|||
|
||||
f: list[Any] | None = f2(1)
|
||||
# TODO: Better constraint solver.
|
||||
reveal_type(f) # revealed: list[Literal[1]] | None
|
||||
reveal_type(f) # revealed: list[int] | None
|
||||
|
||||
g: list[Any] | dict[Any, Any] = f3(1)
|
||||
# TODO: Better constraint solver.
|
||||
reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]]
|
||||
reveal_type(g) # revealed: list[int] | dict[int, int]
|
||||
```
|
||||
|
||||
We currently prefer the generic declared type regardless of its variance:
|
||||
|
|
@ -662,8 +589,8 @@ x4 = invariant(1)
|
|||
|
||||
reveal_type(x1) # revealed: Bivariant[Literal[1]]
|
||||
reveal_type(x2) # revealed: Covariant[Literal[1]]
|
||||
reveal_type(x3) # revealed: Contravariant[Literal[1]]
|
||||
reveal_type(x4) # revealed: Invariant[Literal[1]]
|
||||
reveal_type(x3) # revealed: Contravariant[int]
|
||||
reveal_type(x4) # revealed: Invariant[int]
|
||||
|
||||
x5: Bivariant[Any] = bivariant(1)
|
||||
x6: Covariant[Any] = covariant(1)
|
||||
|
|
|
|||
|
|
@ -16,28 +16,19 @@ python-version = "3.12"
|
|||
```
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def list1[T](x: T) -> list[T]:
|
||||
return [x]
|
||||
|
||||
l1 = list1(1)
|
||||
l1: list[Literal[1]] = list1(1)
|
||||
reveal_type(l1) # revealed: list[Literal[1]]
|
||||
l2: list[int] = list1(1)
|
||||
|
||||
l2 = list1(1)
|
||||
reveal_type(l2) # revealed: list[int]
|
||||
|
||||
# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
|
||||
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
|
||||
l2 = l1
|
||||
|
||||
intermediate = list1(1)
|
||||
# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]`
|
||||
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
|
||||
l3: list[int] = intermediate
|
||||
# TODO: it would be nice if this were `list[int]`
|
||||
reveal_type(intermediate) # revealed: list[Literal[1]]
|
||||
reveal_type(l3) # revealed: list[int]
|
||||
|
||||
l4: list[int | str] | None = list1(1)
|
||||
reveal_type(l4) # revealed: list[int | str]
|
||||
l3: list[int | str] | None = list1(1)
|
||||
reveal_type(l3) # revealed: list[int | str]
|
||||
|
||||
def _(l: list[int] | None = None):
|
||||
l1 = l or list()
|
||||
|
|
@ -50,8 +41,6 @@ def _(l: list[int] | None = None):
|
|||
def f[T](x: T, cond: bool) -> T | list[T]:
|
||||
return x if cond else [x]
|
||||
|
||||
# TODO: Better constraint solver.
|
||||
# error: [invalid-assignment]
|
||||
l5: int | list[int] = f(1, True)
|
||||
```
|
||||
|
||||
|
|
@ -233,6 +222,9 @@ def _(flag: bool):
|
|||
|
||||
def _(c: C):
|
||||
c.x = lst(1)
|
||||
|
||||
# TODO: Use the parameter type of `__set__` as type context to avoid this error.
|
||||
# error: [invalid-assignment]
|
||||
C.x = lst(1)
|
||||
```
|
||||
|
||||
|
|
@ -296,7 +288,7 @@ def f[T](x: T) -> list[T]:
|
|||
|
||||
def _(flag: bool):
|
||||
x1 = f(1) if flag else f(2)
|
||||
reveal_type(x1) # revealed: list[Literal[1]] | list[Literal[2]]
|
||||
reveal_type(x1) # revealed: list[int]
|
||||
|
||||
x2: list[int | None] = f(1) if flag else f(2)
|
||||
reveal_type(x2) # revealed: list[int | None]
|
||||
|
|
|
|||
|
|
@ -978,10 +978,10 @@ reveal_type(ParentDataclass.__init__)
|
|||
reveal_type(ChildOfParentDataclass.__init__)
|
||||
|
||||
result_int = uses_dataclass(42)
|
||||
reveal_type(result_int) # revealed: ChildOfParentDataclass[Literal[42]]
|
||||
reveal_type(result_int) # revealed: ChildOfParentDataclass[int]
|
||||
|
||||
result_str = uses_dataclass("hello")
|
||||
reveal_type(result_str) # revealed: ChildOfParentDataclass[Literal["hello"]]
|
||||
reveal_type(result_str) # revealed: ChildOfParentDataclass[str]
|
||||
```
|
||||
|
||||
## Descriptor-typed fields
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ class C[T, U]:
|
|||
class D[V](C[V, int]):
|
||||
def __init__(self, x: V) -> None: ...
|
||||
|
||||
reveal_type(D(1)) # revealed: D[int]
|
||||
reveal_type(D(1)) # revealed: D[Literal[1]]
|
||||
```
|
||||
|
||||
### Generic class inherits `__init__` from generic base class
|
||||
|
|
@ -334,8 +334,8 @@ class C[T, U]:
|
|||
class D[T, U](C[T, U]):
|
||||
pass
|
||||
|
||||
reveal_type(C(1, "str")) # revealed: C[int, str]
|
||||
reveal_type(D(1, "str")) # revealed: D[int, str]
|
||||
reveal_type(C(1, "str")) # revealed: C[Literal[1], Literal["str"]]
|
||||
reveal_type(D(1, "str")) # revealed: D[Literal[1], Literal["str"]]
|
||||
```
|
||||
|
||||
### Generic class inherits `__init__` from `dict`
|
||||
|
|
@ -358,7 +358,7 @@ context. But from the user's point of view, this is another example of the above
|
|||
```py
|
||||
class C[T, U](tuple[T, U]): ...
|
||||
|
||||
reveal_type(C((1, 2))) # revealed: C[int, int]
|
||||
reveal_type(C((1, 2))) # revealed: C[Literal[1], Literal[2]]
|
||||
```
|
||||
|
||||
### Upcasting a `tuple` to its `Sequence` supertype
|
||||
|
|
@ -442,9 +442,9 @@ class D[T, U]:
|
|||
def __init__(self, t: T, u: U) -> None: ...
|
||||
def __init__(self, *args) -> None: ...
|
||||
|
||||
reveal_type(D("string")) # revealed: D[str, str]
|
||||
reveal_type(D(1)) # revealed: D[str, int]
|
||||
reveal_type(D(1, "string")) # revealed: D[int, str]
|
||||
reveal_type(D("string")) # revealed: D[str, Literal["string"]]
|
||||
reveal_type(D(1)) # revealed: D[str, Literal[1]]
|
||||
reveal_type(D(1, "string")) # revealed: D[Literal[1], Literal["string"]]
|
||||
```
|
||||
|
||||
### Synthesized methods with dataclasses
|
||||
|
|
|
|||
|
|
@ -5,14 +5,196 @@
|
|||
python-version = "3.12"
|
||||
```
|
||||
|
||||
There are certain places where we promote literals to their common supertype:
|
||||
There are certain places where we promote literals to their common supertype.
|
||||
|
||||
## All literal types are promotable
|
||||
|
||||
```py
|
||||
from enum import Enum
|
||||
from typing import Literal, LiteralString
|
||||
|
||||
class MyEnum(Enum):
|
||||
A = 1
|
||||
|
||||
def promote[T](x: T) -> list[T]:
|
||||
return [x]
|
||||
|
||||
def _(
|
||||
lit1: Literal["x"],
|
||||
lit2: LiteralString,
|
||||
lit3: Literal[True],
|
||||
lit4: Literal[b"x"],
|
||||
lit5: Literal[MyEnum.A],
|
||||
):
|
||||
reveal_type(promote(lit1)) # revealed: list[str]
|
||||
reveal_type(promote(lit2)) # revealed: list[str]
|
||||
reveal_type(promote(lit3)) # revealed: list[bool]
|
||||
reveal_type(promote(lit4)) # revealed: list[bytes]
|
||||
reveal_type(promote(lit5)) # revealed: list[MyEnum]
|
||||
```
|
||||
|
||||
Function types are also promoted to their `Callable` form:
|
||||
|
||||
```py
|
||||
def lit6(_: int) -> int:
|
||||
return 0
|
||||
|
||||
reveal_type(promote(lit6)) # revealed: list[(_: int) -> int]
|
||||
```
|
||||
|
||||
## Invariant collection literals are promoted
|
||||
|
||||
The elements of invariant collection literals, i.e. lists, dictionaries, and sets, are promoted:
|
||||
|
||||
```py
|
||||
reveal_type([1, 2, 3]) # revealed: list[Unknown | int]
|
||||
reveal_type({"a": 1, "b": 2, "c": 3}) # revealed: dict[Unknown | str, Unknown | int]
|
||||
reveal_type({"a", "b", "c"}) # revealed: set[Unknown | str]
|
||||
```
|
||||
|
||||
This promotion should not take place if the literal type appears in contravariant position:
|
||||
Covariant collection literals are not promoted:
|
||||
|
||||
```py
|
||||
reveal_type((1, 2, 3)) # revealed: tuple[Literal[1], Literal[2], Literal[3]]
|
||||
reveal_type(frozenset((1, 2, 3))) # revealed: frozenset[Literal[1, 2, 3]]
|
||||
```
|
||||
|
||||
## Invariant and contravariant return types are promoted
|
||||
|
||||
Literals are promoted if they are in non-covariant position in the return type of a generic
|
||||
function, or constructor of a generic class:
|
||||
|
||||
```py
|
||||
class Bivariant[T]:
|
||||
def __init__(self, value: T): ...
|
||||
|
||||
class Covariant[T]:
|
||||
def __init__(self, value: T): ...
|
||||
def pop(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
class Contravariant[T]:
|
||||
def __init__(self, value: T): ...
|
||||
def push(self, value: T) -> None:
|
||||
pass
|
||||
|
||||
class Invariant[T]:
|
||||
x: T
|
||||
|
||||
def __init__(self, value: T): ...
|
||||
|
||||
def f1[T](x: T) -> Bivariant[T] | None: ...
|
||||
def f2[T](x: T) -> Covariant[T] | None: ...
|
||||
def f3[T](x: T) -> Covariant[T] | Bivariant[T] | None: ...
|
||||
def f4[T](x: T) -> Contravariant[T] | None: ...
|
||||
def f5[T](x: T) -> Invariant[T] | None: ...
|
||||
def f6[T](x: T) -> Invariant[T] | Contravariant[T] | None: ...
|
||||
def f7[T](x: T) -> Covariant[T] | Contravariant[T] | None: ...
|
||||
def f8[T](x: T) -> Invariant[T] | Covariant[T] | None: ...
|
||||
def f9[T](x: T) -> tuple[Invariant[T], Invariant[T]] | None: ...
|
||||
def f10[T, U](x: T, y: U) -> tuple[Invariant[T], Covariant[U]] | None: ...
|
||||
def f11[T, U](x: T, y: U) -> tuple[Invariant[Covariant[T] | None], Covariant[U]] | None: ...
|
||||
|
||||
reveal_type(Bivariant(1)) # revealed: Bivariant[Literal[1]]
|
||||
reveal_type(Covariant(1)) # revealed: Covariant[Literal[1]]
|
||||
|
||||
reveal_type(Contravariant(1)) # revealed: Contravariant[int]
|
||||
reveal_type(Invariant(1)) # revealed: Invariant[int]
|
||||
|
||||
reveal_type(f1(1)) # revealed: Bivariant[Literal[1]] | None
|
||||
reveal_type(f2(1)) # revealed: Covariant[Literal[1]] | None
|
||||
reveal_type(f3(1)) # revealed: Covariant[Literal[1]] | Bivariant[Literal[1]] | None
|
||||
|
||||
reveal_type(f4(1)) # revealed: Contravariant[int] | None
|
||||
reveal_type(f5(1)) # revealed: Invariant[int] | None
|
||||
reveal_type(f6(1)) # revealed: Invariant[int] | Contravariant[int] | None
|
||||
reveal_type(f7(1)) # revealed: Covariant[int] | Contravariant[int] | None
|
||||
reveal_type(f8(1)) # revealed: Invariant[int] | Covariant[int] | None
|
||||
reveal_type(f9(1)) # revealed: tuple[Invariant[int], Invariant[int]] | None
|
||||
|
||||
reveal_type(f10(1, 1)) # revealed: tuple[Invariant[int], Covariant[Literal[1]]] | None
|
||||
reveal_type(f11(1, 1)) # revealed: tuple[Invariant[Covariant[int] | None], Covariant[Literal[1]]] | None
|
||||
```
|
||||
|
||||
## Invariant and contravariant literal arguments are respected
|
||||
|
||||
If a literal type is present in non-covariant position in the return type, but also in non-covariant
|
||||
position in an argument type, we respect the explicitly annotated argument, and avoid promotion:
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
class Covariant[T]:
|
||||
def pop(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
class Contravariant[T]:
|
||||
def push(self, value: T) -> None:
|
||||
pass
|
||||
|
||||
class Invariant[T]:
|
||||
x: T
|
||||
|
||||
def f1[T](x: T) -> Invariant[T] | None: ...
|
||||
def f2[T](x: Covariant[T]) -> Invariant[T] | None: ...
|
||||
def f3[T](x: Invariant[T]) -> Invariant[T] | None: ...
|
||||
def f4[T](x: Contravariant[T]) -> Invariant[T] | None: ...
|
||||
def f5[T](x: Covariant[Invariant[T]]) -> Invariant[T] | None: ...
|
||||
def f6[T](x: Covariant[Invariant[T]]) -> Invariant[T] | None: ...
|
||||
def f7[T](x: Covariant[T], y: Invariant[T]) -> Invariant[T] | None: ...
|
||||
def f8[T](x: Invariant[T], y: Covariant[T]) -> Invariant[T] | None: ...
|
||||
def f9[T](x: Covariant[T], y: Contravariant[T]) -> Invariant[T] | None: ...
|
||||
def f10[T](x: Contravariant[T], y: Covariant[T]) -> Invariant[T] | None: ...
|
||||
def _(
|
||||
lit: Literal[1],
|
||||
cov: Covariant[Literal[1]],
|
||||
inv: Invariant[Literal[1]],
|
||||
cont: Contravariant[Literal[1]],
|
||||
inv2: Covariant[Invariant[Literal[1]]],
|
||||
):
|
||||
reveal_type(f1(lit)) # revealed: Invariant[int] | None
|
||||
reveal_type(f2(cov)) # revealed: Invariant[int] | None
|
||||
|
||||
reveal_type(f3(inv)) # revealed: Invariant[Literal[1]] | None
|
||||
reveal_type(f4(cont)) # revealed: Invariant[Literal[1]] | None
|
||||
reveal_type(f5(inv2)) # revealed: Invariant[Literal[1]] | None
|
||||
reveal_type(f6(inv2)) # revealed: Invariant[Literal[1]] | None
|
||||
reveal_type(f7(cov, inv)) # revealed: Invariant[Literal[1]] | None
|
||||
reveal_type(f8(inv, cov)) # revealed: Invariant[Literal[1]] | None
|
||||
reveal_type(f9(cov, cont)) # revealed: Invariant[Literal[1]] | None
|
||||
reveal_type(f10(cont, cov)) # revealed: Invariant[Literal[1]] | None
|
||||
```
|
||||
|
||||
Note that we consider variance of _argument_ types, not parameters. If the literal is in covariant
|
||||
position in the declared parameter type, but invariant in the argument type, we still avoid
|
||||
promotion:
|
||||
|
||||
```py
|
||||
from typing import Iterable
|
||||
|
||||
class X[T]:
|
||||
def __init__(self, x: Iterable[T]): ...
|
||||
|
||||
def _(x: list[Literal[1]]):
|
||||
reveal_type(X(x)) # revealed: X[Literal[1]]
|
||||
```
|
||||
|
||||
## Literals are promoted recursively
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def promote[T](x: T) -> list[T]:
|
||||
return [x]
|
||||
|
||||
def _(x: tuple[tuple[tuple[Literal[1]]]]):
|
||||
reveal_type(promote(x)) # revealed: list[tuple[tuple[tuple[int]]]]
|
||||
|
||||
x1 = ([1, 2], [(3,), (4,)], ["5", "6"])
|
||||
reveal_type(x1) # revealed: tuple[list[Unknown | int], list[Unknown | tuple[int]], list[Unknown | str]]
|
||||
```
|
||||
|
||||
However, this promotion should not take place if the literal type appears in contravariant position:
|
||||
|
||||
```py
|
||||
from typing import Callable, Literal
|
||||
|
|
@ -66,3 +248,93 @@ def _(
|
|||
reveal_type([contravariant]) # revealed: list[Unknown | Contravariant[Literal[1]]]
|
||||
reveal_type([invariant]) # revealed: list[Unknown | Invariant[Literal[1]]]
|
||||
```
|
||||
|
||||
## Literal annnotations are respected
|
||||
|
||||
Explicitly annotated `Literal` types will prevent literal promotion:
|
||||
|
||||
```py
|
||||
from enum import Enum
|
||||
from typing_extensions import Literal, LiteralString
|
||||
|
||||
class Color(Enum):
|
||||
RED = "red"
|
||||
|
||||
type Y[T] = list[T]
|
||||
|
||||
class X[T]:
|
||||
value: T
|
||||
|
||||
def __init__(self, value: T): ...
|
||||
|
||||
def x[T](x: T) -> X[T]:
|
||||
return X(x)
|
||||
|
||||
x1: list[Literal[1]] = [1]
|
||||
reveal_type(x1) # revealed: list[Literal[1]]
|
||||
|
||||
x2: list[Literal[True]] = [True]
|
||||
reveal_type(x2) # revealed: list[Literal[True]]
|
||||
|
||||
x3: list[Literal["a"]] = ["a"]
|
||||
reveal_type(x3) # revealed: list[Literal["a"]]
|
||||
|
||||
x4: list[LiteralString] = ["a", "b", "c"]
|
||||
reveal_type(x4) # revealed: list[LiteralString]
|
||||
|
||||
x5: list[list[Literal[1]]] = [[1]]
|
||||
reveal_type(x5) # revealed: list[list[Literal[1]]]
|
||||
|
||||
x6: dict[list[Literal[1]], list[Literal[Color.RED]]] = {[1]: [Color.RED, Color.RED]}
|
||||
reveal_type(x6) # revealed: dict[list[Literal[1]], list[Color]]
|
||||
|
||||
x7: X[Literal[1]] = X(1)
|
||||
reveal_type(x7) # revealed: X[Literal[1]]
|
||||
|
||||
x8: X[int] = X(1)
|
||||
reveal_type(x8) # revealed: X[int]
|
||||
|
||||
x9: dict[list[X[Literal[1]]], set[Literal[b"a"]]] = {[X(1)]: {b"a"}}
|
||||
reveal_type(x9) # revealed: dict[list[X[Literal[1]]], set[Literal[b"a"]]]
|
||||
|
||||
x10: list[Literal[1, 2, 3]] = [1, 2, 3]
|
||||
reveal_type(x10) # revealed: list[Literal[1, 2, 3]]
|
||||
|
||||
x11: list[Literal[1] | Literal[2] | Literal[3]] = [1, 2, 3]
|
||||
reveal_type(x11) # revealed: list[Literal[1, 2, 3]]
|
||||
|
||||
x12: Y[Y[Literal[1]]] = [[1]]
|
||||
reveal_type(x12) # revealed: list[Y[Literal[1]]]
|
||||
|
||||
x13: list[tuple[Literal[1], Literal[2], Literal[3]]] = [(1, 2, 3)]
|
||||
reveal_type(x13) # revealed: list[tuple[Literal[1], Literal[2], Literal[3]]]
|
||||
|
||||
x14: list[tuple[int, str, int]] = [(1, "2", 3), (4, "5", 6)]
|
||||
reveal_type(x14) # revealed: list[tuple[int, str, int]]
|
||||
|
||||
x15: list[tuple[Literal[1], ...]] = [(1, 1, 1)]
|
||||
reveal_type(x15) # revealed: list[tuple[Literal[1], ...]]
|
||||
|
||||
x16: list[tuple[int, ...]] = [(1, 1, 1)]
|
||||
reveal_type(x16) # revealed: list[tuple[int, ...]]
|
||||
|
||||
x17: list[int | Literal[1]] = [1]
|
||||
reveal_type(x17) # revealed: list[int]
|
||||
|
||||
x18: list[Literal[1, 2, 3, 4]] = [1, 2]
|
||||
reveal_type(x18) # revealed: list[Literal[1, 2, 3, 4]]
|
||||
|
||||
x19: list[Literal[1]]
|
||||
|
||||
x19 = [1]
|
||||
reveal_type(x19) # revealed: list[Literal[1]]
|
||||
|
||||
(x19 := [1])
|
||||
reveal_type(x19) # revealed: list[Literal[1]]
|
||||
|
||||
x20: list[Literal[1]] | None = [1]
|
||||
reveal_type(x20) # revealed: list[Literal[1]]
|
||||
|
||||
x21: X[Literal[1]] | None = x(1)
|
||||
reveal_type(x21) # revealed: X[Literal[1]]
|
||||
```
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ dd: defaultdict[int, int] = defaultdict(int)
|
|||
dd[0] = 0
|
||||
cm: ChainMap[int, int] = ChainMap({1: 1}, {0: 0})
|
||||
cm[0] = 0
|
||||
reveal_type(cm) # revealed: ChainMap[int | Unknown, int | Unknown]
|
||||
reveal_type(cm) # revealed: ChainMap[int, int]
|
||||
|
||||
reveal_type(l[0]) # revealed: Literal[0]
|
||||
reveal_type(d[0]) # revealed: Literal[0]
|
||||
|
|
|
|||
|
|
@ -514,10 +514,8 @@ For covariant types, such as `frozenset`, the ideal behaviour would be to not pr
|
|||
types to their instance supertypes: doing so causes more false positives than it fixes:
|
||||
|
||||
```py
|
||||
# TODO: better here would be `frozenset[Literal[1, 2, 3]]`
|
||||
reveal_type(frozenset((1, 2, 3))) # revealed: frozenset[int]
|
||||
# TODO: better here would be `frozenset[tuple[Literal[1], Literal[2], Literal[3]]]`
|
||||
reveal_type(frozenset(((1, 2, 3),))) # revealed: frozenset[tuple[int, int, int]]
|
||||
reveal_type(frozenset((1, 2, 3))) # revealed: frozenset[Literal[1, 2, 3]]
|
||||
reveal_type(frozenset(((1, 2, 3),))) # revealed: frozenset[tuple[Literal[1], Literal[2], Literal[3]]]
|
||||
```
|
||||
|
||||
Literals are always promoted for invariant containers such as `list`, however, even though this can
|
||||
|
|
|
|||
|
|
@ -126,21 +126,21 @@ Also, the value types declared in a `TypedDict` affect generic call infere
|
|||
|
||||
```py
|
||||
class Plot(TypedDict):
|
||||
y: list[int]
|
||||
x: list[int] | None
|
||||
y: list[int | None]
|
||||
x: list[int | None] | None
|
||||
|
||||
plot1: Plot = {"y": [1, 2, 3], "x": None}
|
||||
|
||||
def homogeneous_list[T](*args: T) -> list[T]:
|
||||
return list(args)
|
||||
|
||||
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
|
||||
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[int]
|
||||
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
|
||||
reveal_type(plot2["y"]) # revealed: list[int]
|
||||
reveal_type(plot2["y"]) # revealed: list[int | None]
|
||||
|
||||
plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
|
||||
reveal_type(plot3["y"]) # revealed: list[int]
|
||||
reveal_type(plot3["x"]) # revealed: list[int] | None
|
||||
reveal_type(plot3["y"]) # revealed: list[int | None]
|
||||
reveal_type(plot3["x"]) # revealed: list[int | None] | None
|
||||
|
||||
Y = "y"
|
||||
X = "x"
|
||||
|
|
|
|||
|
|
@ -243,6 +243,10 @@ pub(crate) type TryBoolVisitor<'db> =
|
|||
CycleDetector<TryBool, Type<'db>, Result<Truthiness, BoolError<'db>>>;
|
||||
pub(crate) struct TryBool;
|
||||
|
||||
/// A [`CycleDetector`] that is used in `visit_specialization` methods.
|
||||
pub(crate) type SpecializationVisitor<'db> = CycleDetector<VisitSpecialization, Type<'db>, ()>;
|
||||
pub(crate) struct VisitSpecialization;
|
||||
|
||||
/// A [`TypeTransformer`] that is used in `normalized` methods.
|
||||
pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>;
|
||||
|
||||
|
|
@ -3376,6 +3380,81 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Recursively visit the specialization of a generic class instance.
|
||||
///
|
||||
/// The provided closure will be called with each assignment of a type variable present in this
|
||||
/// type, along with the variance of the outermost type with respect to the type variable.
|
||||
///
|
||||
/// If a `TypeContext` is provided, it will be narrowed as nested types are visited, if the
|
||||
/// type is a specialized instance of the same class.
|
||||
pub(crate) fn visit_specialization<F>(self, db: &'db dyn Db, tcx: TypeContext<'db>, mut f: F)
|
||||
where
|
||||
F: FnMut(BoundTypeVarInstance<'db>, Type<'db>, TypeVarVariance, TypeContext<'db>),
|
||||
{
|
||||
self.visit_specialization_impl(
|
||||
db,
|
||||
tcx,
|
||||
TypeVarVariance::Covariant,
|
||||
&mut f,
|
||||
&SpecializationVisitor::default(),
|
||||
);
|
||||
}
|
||||
|
||||
fn visit_specialization_impl(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
tcx: TypeContext<'db>,
|
||||
polarity: TypeVarVariance,
|
||||
f: &mut dyn FnMut(BoundTypeVarInstance<'db>, Type<'db>, TypeVarVariance, TypeContext<'db>),
|
||||
visitor: &SpecializationVisitor<'db>,
|
||||
) {
|
||||
let Type::NominalInstance(instance) = self else {
|
||||
match self {
|
||||
Type::Union(union) => {
|
||||
for element in union.elements(db) {
|
||||
element.visit_specialization_impl(db, tcx, polarity, f, visitor);
|
||||
}
|
||||
}
|
||||
Type::Intersection(intersection) => {
|
||||
for element in intersection.positive(db) {
|
||||
element.visit_specialization_impl(db, tcx, polarity, f, visitor);
|
||||
}
|
||||
}
|
||||
Type::TypeAlias(alias) => visitor.visit(self, || {
|
||||
alias
|
||||
.value_type(db)
|
||||
.visit_specialization_impl(db, tcx, polarity, f, visitor);
|
||||
}),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
let (class_literal, Some(specialization)) = instance.class(db).class_literal(db) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let tcx_specialization = tcx
|
||||
.annotation
|
||||
.and_then(|tcx| tcx.specialization_of(db, class_literal));
|
||||
|
||||
for (typevar, ty) in specialization
|
||||
.generic_context(db)
|
||||
.variables(db)
|
||||
.zip(specialization.types(db))
|
||||
{
|
||||
let variance = typevar.variance_with_polarity(db, polarity);
|
||||
let tcx = TypeContext::new(tcx_specialization.and_then(|spec| spec.get(db, typevar)));
|
||||
|
||||
f(typevar, *ty, variance, tcx);
|
||||
|
||||
visitor.visit(*ty, || {
|
||||
ty.visit_specialization_impl(db, tcx, variance, f, visitor);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Return true if there is just a single inhabitant for this type.
|
||||
///
|
||||
/// Note: This function aims to have no false positives, but might return `false`
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use std::fmt;
|
|||
use itertools::{Either, Itertools};
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use ruff_python_ast::name::Name;
|
||||
use rustc_hash::FxHashSet;
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use smallvec::{SmallVec, smallvec, smallvec_inline};
|
||||
|
||||
use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Signature, Type};
|
||||
|
|
@ -38,8 +38,8 @@ use crate::types::{
|
|||
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
|
||||
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
|
||||
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
|
||||
TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support,
|
||||
infer_isolated_expression, todo_type,
|
||||
TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind,
|
||||
enums, ide_support, infer_isolated_expression, todo_type,
|
||||
};
|
||||
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
||||
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
|
||||
|
|
@ -2808,6 +2808,11 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
Some(builder.type_mappings().clone())
|
||||
});
|
||||
|
||||
// For a given type variable, we track the variance of any assignments to that type variable
|
||||
// in the argument types.
|
||||
let mut variance_in_arguments: FxHashMap<BoundTypeVarIdentity<'_>, TypeVarVariance> =
|
||||
FxHashMap::default();
|
||||
|
||||
let parameters = self.signature.parameters();
|
||||
for (argument_index, adjusted_argument_index, _, argument_type) in
|
||||
self.enumerate_argument_types()
|
||||
|
|
@ -2820,22 +2825,32 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
continue;
|
||||
};
|
||||
|
||||
let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| {
|
||||
// Avoid widening the inferred type if it is already assignable to the
|
||||
// preferred declared type.
|
||||
preferred_type_mappings
|
||||
.as_ref()
|
||||
.and_then(|types| types.get(&declared_ty))
|
||||
.is_none_or(|preferred_ty| {
|
||||
!inferred_ty.is_assignable_to(self.db, *preferred_ty)
|
||||
})
|
||||
};
|
||||
|
||||
if let Err(error) = builder.infer_filter(
|
||||
let specialization_result = builder.infer_map(
|
||||
expected_type,
|
||||
variadic_argument_type.unwrap_or(argument_type),
|
||||
filter,
|
||||
) {
|
||||
|(identity, variance, inferred_ty)| {
|
||||
// Avoid widening the inferred type if it is already assignable to the
|
||||
// preferred declared type.
|
||||
if preferred_type_mappings
|
||||
.as_ref()
|
||||
.and_then(|types| types.get(&identity))
|
||||
.is_some_and(|preferred_ty| {
|
||||
inferred_ty.is_assignable_to(self.db, *preferred_ty)
|
||||
})
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
variance_in_arguments
|
||||
.entry(identity)
|
||||
.and_modify(|current| *current = current.join(variance))
|
||||
.or_insert(variance);
|
||||
|
||||
Some(inferred_ty)
|
||||
},
|
||||
);
|
||||
|
||||
if let Err(error) = specialization_result {
|
||||
self.errors.push(BindingError::SpecializationError {
|
||||
error,
|
||||
argument_index: adjusted_argument_index,
|
||||
|
|
@ -2844,8 +2859,59 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
}
|
||||
}
|
||||
|
||||
// Attempt to promote any literal types assigned to the specialization.
|
||||
let maybe_promote = |identity, typevar, ty: Type<'db>| {
|
||||
let Some(return_ty) = self.constructor_instance_type.or(self.signature.return_ty)
|
||||
else {
|
||||
return ty;
|
||||
};
|
||||
|
||||
let mut combined_tcx = TypeContext::default();
|
||||
let mut variance_in_return = TypeVarVariance::Bivariant;
|
||||
|
||||
// Find all occurrences of the type variable in the return type.
|
||||
let visit_return_ty = |_, ty, variance, tcx: TypeContext<'db>| {
|
||||
if ty != Type::TypeVar(typevar) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We always prefer the declared type when attempting literal promotion,
|
||||
// so we take the union of every applicable type context.
|
||||
match (tcx.annotation, &mut combined_tcx.annotation) {
|
||||
(Some(_), None) => combined_tcx = tcx,
|
||||
(Some(ty), Some(combined_ty)) => {
|
||||
*combined_ty = UnionType::from_elements(self.db, [*combined_ty, ty]);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
variance_in_return = variance_in_return.join(variance);
|
||||
};
|
||||
|
||||
return_ty.visit_specialization(self.db, self.call_expression_tcx, visit_return_ty);
|
||||
|
||||
// Promotion is only useful if the type variable is in invariant or contravariant
|
||||
// position in the return type.
|
||||
if variance_in_return.is_covariant() {
|
||||
return ty;
|
||||
}
|
||||
|
||||
// If the type variable is a non-covariant position in the argument, then we avoid
|
||||
// promotion, respecting any literals in the parameter type.
|
||||
if variance_in_arguments
|
||||
.get(&identity)
|
||||
.is_some_and(|variance| !variance.is_covariant())
|
||||
{
|
||||
return ty;
|
||||
}
|
||||
|
||||
ty.promote_literals(self.db, combined_tcx)
|
||||
};
|
||||
|
||||
// Build the specialization first without inferring the complete type context.
|
||||
let isolated_specialization = builder.build(generic_context, self.call_expression_tcx);
|
||||
let isolated_specialization = builder
|
||||
.mapped(generic_context, maybe_promote)
|
||||
.build(generic_context);
|
||||
let isolated_return_ty = self
|
||||
.return_ty
|
||||
.apply_specialization(self.db, isolated_specialization);
|
||||
|
|
@ -2870,7 +2936,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
builder.infer(return_ty, call_expression_tcx).ok()?;
|
||||
|
||||
// Otherwise, build the specialization again after inferring the complete type context.
|
||||
let specialization = builder.build(generic_context, self.call_expression_tcx);
|
||||
let specialization = builder
|
||||
.mapped(generic_context, maybe_promote)
|
||||
.build(generic_context);
|
||||
let return_ty = return_ty.apply_specialization(self.db, specialization);
|
||||
|
||||
Some((Some(specialization), return_ty))
|
||||
|
|
|
|||
|
|
@ -1476,14 +1476,9 @@ impl<'db> ClassLiteral<'db> {
|
|||
visitor.typevars.into_inner()
|
||||
}
|
||||
|
||||
/// Returns the generic context that should be inherited by any constructor methods of this
|
||||
/// class.
|
||||
///
|
||||
/// When inferring a specialization of the class's generic context from a constructor call, we
|
||||
/// promote any typevars that are inferred as a literal to the corresponding instance type.
|
||||
/// Returns the generic context that should be inherited by any constructor methods of this class.
|
||||
fn inherited_generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
|
||||
self.generic_context(db)
|
||||
.map(|generic_context| generic_context.promote_literals(db))
|
||||
}
|
||||
|
||||
pub(super) fn file(self, db: &dyn Db) -> File {
|
||||
|
|
|
|||
|
|
@ -181,26 +181,6 @@ impl<'a, 'db> InferableTypeVars<'a, 'db> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)]
|
||||
pub struct GenericContextTypeVar<'db> {
|
||||
bound_typevar: BoundTypeVarInstance<'db>,
|
||||
should_promote_literals: bool,
|
||||
}
|
||||
|
||||
impl<'db> GenericContextTypeVar<'db> {
|
||||
fn new(bound_typevar: BoundTypeVarInstance<'db>) -> Self {
|
||||
Self {
|
||||
bound_typevar,
|
||||
should_promote_literals: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn promote_literals(mut self) -> Self {
|
||||
self.should_promote_literals = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A list of formal type variables for a generic function, class, or type alias.
|
||||
///
|
||||
/// # Ordering
|
||||
|
|
@ -210,7 +190,7 @@ impl<'db> GenericContextTypeVar<'db> {
|
|||
#[derive(PartialOrd, Ord)]
|
||||
pub struct GenericContext<'db> {
|
||||
#[returns(ref)]
|
||||
variables_inner: FxOrderMap<BoundTypeVarIdentity<'db>, GenericContextTypeVar<'db>>,
|
||||
variables_inner: FxOrderMap<BoundTypeVarIdentity<'db>, BoundTypeVarInstance<'db>>,
|
||||
}
|
||||
|
||||
pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
|
||||
|
|
@ -227,19 +207,6 @@ pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?S
|
|||
impl get_size2::GetSize for GenericContext<'_> {}
|
||||
|
||||
impl<'db> GenericContext<'db> {
|
||||
fn from_variables(
|
||||
db: &'db dyn Db,
|
||||
variables: impl IntoIterator<Item = GenericContextTypeVar<'db>>,
|
||||
) -> Self {
|
||||
Self::new_internal(
|
||||
db,
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|variable| (variable.bound_typevar.identity(db), variable))
|
||||
.collect::<FxOrderMap<_, _>>(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a generic context from a list of PEP-695 type parameters.
|
||||
pub(crate) fn from_type_params(
|
||||
db: &'db dyn Db,
|
||||
|
|
@ -259,24 +226,19 @@ impl<'db> GenericContext<'db> {
|
|||
db: &'db dyn Db,
|
||||
type_params: impl IntoIterator<Item = BoundTypeVarInstance<'db>>,
|
||||
) -> Self {
|
||||
Self::from_variables(db, type_params.into_iter().map(GenericContextTypeVar::new))
|
||||
}
|
||||
|
||||
/// Returns a copy of this generic context where we will promote literal types in any inferred
|
||||
/// specializations.
|
||||
pub(crate) fn promote_literals(self, db: &'db dyn Db) -> Self {
|
||||
Self::from_variables(
|
||||
Self::new_internal(
|
||||
db,
|
||||
self.variables_inner(db)
|
||||
.values()
|
||||
.map(|variable| variable.promote_literals()),
|
||||
type_params
|
||||
.into_iter()
|
||||
.map(|variable| (variable.identity(db), variable))
|
||||
.collect::<FxOrderMap<_, _>>(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Merge this generic context with another, returning a new generic context that
|
||||
/// contains type variables from both contexts.
|
||||
pub(crate) fn merge(self, db: &'db dyn Db, other: Self) -> Self {
|
||||
Self::from_variables(
|
||||
Self::from_typevar_instances(
|
||||
db,
|
||||
self.variables_inner(db)
|
||||
.values()
|
||||
|
|
@ -339,9 +301,7 @@ impl<'db> GenericContext<'db> {
|
|||
self,
|
||||
db: &'db dyn Db,
|
||||
) -> impl ExactSizeIterator<Item = BoundTypeVarInstance<'db>> + Clone {
|
||||
self.variables_inner(db)
|
||||
.values()
|
||||
.map(|variable| variable.bound_typevar)
|
||||
self.variables_inner(db).values().copied()
|
||||
}
|
||||
|
||||
fn variable_from_type_param(
|
||||
|
|
@ -611,7 +571,7 @@ impl<'db> GenericContext<'db> {
|
|||
}
|
||||
|
||||
fn heap_size(
|
||||
(variables,): &(FxOrderMap<BoundTypeVarIdentity<'db>, GenericContextTypeVar<'db>>,),
|
||||
(variables,): &(FxOrderMap<BoundTypeVarIdentity<'db>, BoundTypeVarInstance<'db>>,),
|
||||
) -> usize {
|
||||
ruff_memory_usage::order_map_heap_size(variables)
|
||||
}
|
||||
|
|
@ -1307,6 +1267,10 @@ pub(crate) struct SpecializationBuilder<'db> {
|
|||
types: FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>>,
|
||||
}
|
||||
|
||||
/// An assignment from a bound type variable to a given type, along with the variance of the outermost
|
||||
/// type with respect to the type variable.
|
||||
pub(crate) type TypeVarAssignment<'db> = (BoundTypeVarIdentity<'db>, TypeVarVariance, Type<'db>);
|
||||
|
||||
impl<'db> SpecializationBuilder<'db> {
|
||||
pub(crate) fn new(db: &'db dyn Db, inferable: InferableTypeVars<'db, 'db>) -> Self {
|
||||
Self {
|
||||
|
|
@ -1321,31 +1285,31 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
&self.types
|
||||
}
|
||||
|
||||
pub(crate) fn build(
|
||||
&mut self,
|
||||
/// Map the types that have been assigned in this specialization.
|
||||
pub(crate) fn mapped(
|
||||
&self,
|
||||
generic_context: GenericContext<'db>,
|
||||
tcx: TypeContext<'db>,
|
||||
) -> Specialization<'db> {
|
||||
let tcx_specialization = tcx
|
||||
.annotation
|
||||
.and_then(|annotation| annotation.class_specialization(self.db));
|
||||
f: impl Fn(BoundTypeVarIdentity<'db>, BoundTypeVarInstance<'db>, Type<'db>) -> Type<'db>,
|
||||
) -> Self {
|
||||
let mut types = self.types.clone();
|
||||
for (identity, variable) in generic_context.variables_inner(self.db) {
|
||||
if let Some(ty) = types.get_mut(identity) {
|
||||
*ty = f(*identity, *variable, *ty);
|
||||
}
|
||||
}
|
||||
|
||||
let types =
|
||||
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
|
||||
let mut ty = self.types.get(identity).copied();
|
||||
Self {
|
||||
db: self.db,
|
||||
inferable: self.inferable,
|
||||
types,
|
||||
}
|
||||
}
|
||||
|
||||
// When inferring a specialization for a generic class typevar from a constructor call,
|
||||
// promote any typevars that are inferred as a literal to the corresponding instance type.
|
||||
if variable.should_promote_literals {
|
||||
let tcx = tcx_specialization.and_then(|specialization| {
|
||||
specialization.get(self.db, variable.bound_typevar)
|
||||
});
|
||||
|
||||
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
|
||||
}
|
||||
|
||||
ty
|
||||
});
|
||||
pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> {
|
||||
let types = generic_context
|
||||
.variables_inner(self.db)
|
||||
.iter()
|
||||
.map(|(identity, _)| self.types.get(identity).copied());
|
||||
|
||||
// TODO Infer the tuple spec for a tuple type
|
||||
generic_context.specialize_partial(self.db, types)
|
||||
|
|
@ -1355,14 +1319,17 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
&mut self,
|
||||
bound_typevar: BoundTypeVarInstance<'db>,
|
||||
ty: Type<'db>,
|
||||
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
|
||||
variance: TypeVarVariance,
|
||||
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
|
||||
) {
|
||||
let identity = bound_typevar.identity(self.db);
|
||||
let Some(ty) = f((identity, variance, ty)) else {
|
||||
return;
|
||||
};
|
||||
|
||||
match self.types.entry(identity) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
if filter(identity, ty) {
|
||||
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
|
||||
}
|
||||
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(ty);
|
||||
|
|
@ -1376,18 +1343,28 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
formal: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
self.infer_filter(formal, actual, |_, _| true)
|
||||
self.infer_map(formal, actual, |(_, _, ty)| Some(ty))
|
||||
}
|
||||
|
||||
/// Infer type mappings for the specialization based on a given type and its declared type.
|
||||
///
|
||||
/// The filter predicate is provided with a type variable and the type being mapped to it. Type
|
||||
/// mappings to which the predicate returns `false` will be ignored.
|
||||
pub(crate) fn infer_filter(
|
||||
/// The provided function will be called before any type mappings are created, and can
|
||||
/// optionally modify the inferred type, or filter out the type mapping entirely.
|
||||
pub(crate) fn infer_map(
|
||||
&mut self,
|
||||
formal: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool,
|
||||
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
self.infer_map_impl(formal, actual, TypeVarVariance::Covariant, &mut f)
|
||||
}
|
||||
|
||||
fn infer_map_impl(
|
||||
&mut self,
|
||||
formal: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
polarity: TypeVarVariance,
|
||||
mut f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
if formal == actual {
|
||||
return Ok(());
|
||||
|
|
@ -1445,7 +1422,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
if remaining_actual.is_never() {
|
||||
return Ok(());
|
||||
}
|
||||
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
|
||||
self.add_type_mapping(*formal_bound_typevar, remaining_actual, polarity, f);
|
||||
}
|
||||
(Type::Union(formal), _) => {
|
||||
// Second, if the formal is a union, and precisely one union element is assignable
|
||||
|
|
@ -1475,7 +1452,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
let bound_typevars =
|
||||
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
|
||||
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
|
||||
self.add_type_mapping(bound_typevar, actual, filter);
|
||||
self.add_type_mapping(bound_typevar, actual, polarity, f);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1485,7 +1462,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
// actual type must also be disjoint from every negative element of the
|
||||
// intersection, but that doesn't help us infer any type mappings.)
|
||||
for positive in formal.iter_positive(self.db) {
|
||||
self.infer(positive, actual)?;
|
||||
self.infer_map_impl(positive, actual, polarity, f)?;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1503,13 +1480,13 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
argument: ty,
|
||||
});
|
||||
}
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
self.add_type_mapping(bound_typevar, ty, polarity, f);
|
||||
}
|
||||
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
|
||||
// Prefer an exact match first.
|
||||
for constraint in constraints.elements(self.db) {
|
||||
if ty == *constraint {
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
self.add_type_mapping(bound_typevar, ty, polarity, f);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
|
@ -1519,7 +1496,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
.when_assignable_to(self.db, *constraint, self.inferable)
|
||||
.is_always_satisfied(self.db)
|
||||
{
|
||||
self.add_type_mapping(bound_typevar, *constraint, filter);
|
||||
self.add_type_mapping(bound_typevar, *constraint, polarity, f);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
|
@ -1529,7 +1506,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
});
|
||||
}
|
||||
_ => {
|
||||
self.add_type_mapping(bound_typevar, ty, filter);
|
||||
self.add_type_mapping(bound_typevar, ty, polarity, f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1554,7 +1531,8 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
for (formal_element, actual_element) in
|
||||
formal_tuple.all_elements().zip(actual_tuple.all_elements())
|
||||
{
|
||||
self.infer(*formal_element, *actual_element)?;
|
||||
let variance = TypeVarVariance::Covariant.compose(polarity);
|
||||
self.infer_map_impl(*formal_element, *actual_element, variance, &mut f)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -1584,13 +1562,20 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
if formal_origin != base_alias.origin(self.db) {
|
||||
continue;
|
||||
}
|
||||
let generic_context = formal_alias
|
||||
.specialization(self.db)
|
||||
.generic_context(self.db)
|
||||
.variables(self.db);
|
||||
let formal_specialization =
|
||||
formal_alias.specialization(self.db).types(self.db);
|
||||
let base_specialization = base_alias.specialization(self.db).types(self.db);
|
||||
for (formal_ty, base_ty) in
|
||||
formal_specialization.iter().zip(base_specialization)
|
||||
{
|
||||
self.infer(*formal_ty, *base_ty)?;
|
||||
for (typevar, formal_ty, base_ty) in itertools::izip!(
|
||||
generic_context,
|
||||
formal_specialization,
|
||||
base_specialization
|
||||
) {
|
||||
let variance = typevar.variance_with_polarity(self.db, polarity);
|
||||
self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6707,7 +6707,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
SpecializationBuilder::new(db, generic_context.inferable_typevars(db));
|
||||
|
||||
let _ = builder.infer(return_ty, declared_return_ty);
|
||||
let specialization = builder.build(generic_context, call_expression_tcx);
|
||||
let specialization = builder.build(generic_context);
|
||||
|
||||
parameter_type = parameter_type.apply_specialization(db, specialization);
|
||||
}
|
||||
|
|
@ -7394,9 +7394,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
let class_type = class_literal.apply_specialization(self.db(), |_| {
|
||||
builder.build(generic_context, TypeContext::default())
|
||||
});
|
||||
let class_type =
|
||||
class_literal.apply_specialization(self.db(), |_| builder.build(generic_context));
|
||||
|
||||
Type::from(class_type).to_instance(self.db())
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue