[red-knot] optimize building large unions of literals (#17403)

## Summary

Special-case literal types in `UnionBuilder` to speed up building large
unions of literals.

This optimization is extremely effective at speeding up building even a
very large union (it improves the large-unions benchmark by 41x!). The
remaining problem is that a later operation on that very large union can
still be slow: for instance, narrowing may add the union to an
intersection, which then has to distribute over the union's members.

I think it is possible to avoid this by extending this optimized
"grouped" representation beyond `UnionBuilder` to all of our union and
intersection representations. I have some work in this direction, but
rather than spending more time on it right now, I'd rather just land
this much, along with a limit on the size of these unions (to avoid
building really big unions quickly and then hitting issues where they
are used).
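
To make the "grouped" idea concrete, here is a minimal, self-contained sketch. It is not the code in this commit: `Ty`, `IntSet`, `Element`, `Builder` and the toy `is_subtype` relation are all hypothetical stand-ins, and only int literals are grouped for brevity. It shows literals going into an insertion-ordered set without any per-literal subtype checks, while a single representative per group answers subtype queries for non-literal types.

```rust
use std::collections::HashSet;

/// Toy stand-in for a type (hypothetical; not red-knot's `Type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Ty {
    Int(i64),
    Str(&'static str),
    Class(&'static str), // e.g. "int", "str", "object"
}

/// Minimal insertion-ordered set of ints, standing in for `FxOrderSet<i64>`.
#[derive(Default)]
struct IntSet {
    items: Vec<i64>,
    seen: HashSet<i64>,
}

impl IntSet {
    fn insert(&mut self, value: i64) {
        // O(1) duplicate check; no subtype queries against other literals.
        if self.seen.insert(value) {
            self.items.push(value);
        }
    }
}

enum Element {
    Ints(IntSet), // invariant: never empty
    Other(Ty),
}

impl Element {
    /// All int literals share the same supertypes, so the first one can
    /// stand in for the whole group in subtype checks.
    fn representative(&self) -> Ty {
        match self {
            Element::Ints(set) => Ty::Int(set.items[0]),
            Element::Other(ty) => *ty,
        }
    }
}

/// Toy subtype relation: a literal is a subtype of its class and of `object`.
fn is_subtype(sub: Ty, sup: Ty) -> bool {
    match (sub, sup) {
        (_, Ty::Class("object")) => true,
        (Ty::Int(_), Ty::Class("int")) | (Ty::Str(_), Ty::Class("str")) => true,
        _ => sub == sup,
    }
}

#[derive(Default)]
pub struct Builder {
    elements: Vec<Element>,
}

impl Builder {
    pub fn add(mut self, ty: Ty) -> Self {
        if let Ty::Int(n) = ty {
            for element in &mut self.elements {
                match element {
                    Element::Ints(set) => {
                        set.insert(n);
                        return self;
                    }
                    Element::Other(existing) if is_subtype(ty, *existing) => return self,
                    Element::Other(_) => {}
                }
            }
            let mut set = IntSet::default();
            set.insert(n);
            self.elements.push(Element::Ints(set));
            return self;
        }
        // Generic path: one check per element, not one per literal.
        if self.elements.iter().any(|e| is_subtype(ty, e.representative())) {
            return self;
        }
        self.elements.retain(|e| !is_subtype(e.representative(), ty));
        self.elements.push(Element::Other(ty));
        self
    }

    /// Flatten the groups back into a plain list of member types.
    pub fn build(self) -> Vec<Ty> {
        let mut types = Vec::new();
        for element in self.elements {
            match element {
                Element::Ints(set) => types.extend(set.items.into_iter().map(Ty::Int)),
                Element::Other(ty) => types.push(ty),
            }
        }
        types
    }
}
```

In this sketch, adding `1, 2, 3, "foo", 5` builds the members in the order `1, 2, 3, 5, "foo"`, because `5` joins the existing int group. That is the same kind of reordering visible in the updated `reveal_type` expectations in this commit.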

## Test Plan

Existing tests and benchmarks.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Carl Meyer 2025-04-16 06:55:37 -07:00 committed by GitHub
parent 13ea4e5d0e
commit a1f361949e
3 changed files with 148 additions and 22 deletions

View File

@@ -68,7 +68,7 @@ def x(
     a3: Literal[Literal["w"], Literal["r"], Literal[Literal["w+"]]],
     a4: Literal[True] | Literal[1, 2] | Literal["foo"],
 ):
-    reveal_type(a1) # revealed: Literal[1, 2, 3, "foo", 5] | None
+    reveal_type(a1) # revealed: Literal[1, 2, 3, 5, "foo"] | None
     reveal_type(a2) # revealed: Literal["w", "r"]
     reveal_type(a3) # revealed: Literal["w", "r", "w+"]
     reveal_type(a4) # revealed: Literal[True, 1, 2, "foo"]
@@ -108,7 +108,7 @@ def union_example(
         None,
     ],
 ):
-    reveal_type(x) # revealed: Unknown | Literal[-1, "A", b"A", b"\x00", b"\x07", 0, 1, "B", "foo", "bar", True] | None
+    reveal_type(x) # revealed: Unknown | Literal[-1, 0, 1, "A", "B", "foo", "bar", b"A", b"\x00", b"\x07", True] | None
 ```

 ## Detecting Literal outside typing and typing_extensions

View File

@@ -105,7 +105,7 @@ def f1(
 from typing import Literal

 def f(v: Literal["a", r"b", b"c", "d" "e", "\N{LATIN SMALL LETTER F}", "\x67", """h"""]):
-    reveal_type(v) # revealed: Literal["a", "b", b"c", "de", "f", "g", "h"]
+    reveal_type(v) # revealed: Literal["a", "b", "de", "f", "g", "h", b"c"]
 ```

 ## Class variables

View File

@@ -12,6 +12,11 @@
 //!   flattens into the outer one), intersections cannot contain other intersections (also
 //!   flattens), and intersections cannot contain unions (the intersection distributes over the
 //!   union, inverting it into a union-of-intersections).
+//! * No type in a union can be a subtype of any other type in the union (just eliminate the
+//!   subtype from the union).
+//! * No type in an intersection can be a supertype of any other type in the intersection (just
+//!   eliminate the supertype from the intersection).
+//! * An intersection containing two non-overlapping types simplifies to [`Type::Never`].
 //!
 //! The implication of these invariants is that a [`UnionBuilder`] does not necessarily build a
 //! [`Type::Union`]. For example, if only one type is added to the [`UnionBuilder`], `build()` will
@@ -19,19 +24,35 @@
 //! union type is added to the intersection, it will distribute and [`IntersectionBuilder::build`]
 //! may end up returning a [`Type::Union`] of intersections.
 //!
-//! In the future we should have these additional invariants, but they aren't implemented yet:
-//! * No type in a union can be a subtype of any other type in the union (just eliminate the
-//!   subtype from the union).
-//! * No type in an intersection can be a supertype of any other type in the intersection (just
-//!   eliminate the supertype from the intersection).
-//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
+//! ## Performance
+//!
+//! In practice, there are two kinds of unions found in the wild: relatively-small unions made up
+//! of normal user types (classes, etc), and large unions made up of literals, which can occur via
+//! large enums (not yet implemented) or from string/integer/bytes literals, which can grow due to
+//! literal arithmetic or operations on literal strings/bytes. For normal unions, it's most
+//! efficient to just store the member types in a vector, and do O(n^2) `is_subtype_of` checks to
+//! maintain the union in simplified form. But literal unions can grow to a size where this becomes
+//! a performance problem. For this reason, we group literal types in `UnionBuilder`. Since every
+//! different string literal type shares exactly the same possible super-types, and none of them
+//! are subtypes of each other (unless exactly the same literal type), we can avoid many
+//! unnecessary `is_subtype_of` checks.

-use crate::types::{IntersectionType, KnownClass, Type, TypeVarBoundOrConstraints, UnionType};
+use crate::types::{
+    BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
+    TypeVarBoundOrConstraints, UnionType,
+};
 use crate::{Db, FxOrderSet};
 use smallvec::SmallVec;

+enum UnionElement<'db> {
+    IntLiterals(FxOrderSet<i64>),
+    StringLiterals(FxOrderSet<StringLiteralType<'db>>),
+    BytesLiterals(FxOrderSet<BytesLiteralType<'db>>),
+    Type(Type<'db>),
+}
+
 pub(crate) struct UnionBuilder<'db> {
-    elements: Vec<Type<'db>>,
+    elements: Vec<UnionElement<'db>>,
     db: &'db dyn Db,
 }
@@ -50,7 +71,8 @@ impl<'db> UnionBuilder<'db> {
     /// Collapse the union to a single type: `object`.
     fn collapse_to_object(mut self) -> Self {
         self.elements.clear();
-        self.elements.push(Type::object(self.db));
+        self.elements
+            .push(UnionElement::Type(Type::object(self.db)));
         self
     }
@@ -66,6 +88,76 @@ impl<'db> UnionBuilder<'db> {
             }
             // Adding `Never` to a union is a no-op.
             Type::Never => {}
+            // If adding a string literal, look for an existing `UnionElement::StringLiterals` to
+            // add it to, or an existing element that is a super-type of string literals, which
+            // means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals`
+            // containing it.
+            Type::StringLiteral(literal) => {
+                let mut found = false;
+                for element in &mut self.elements {
+                    match element {
+                        UnionElement::StringLiterals(literals) => {
+                            literals.insert(literal);
+                            found = true;
+                            break;
+                        }
+                        UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
+                            return self;
+                        }
+                        _ => {}
+                    }
+                }
+                if !found {
+                    self.elements
+                        .push(UnionElement::StringLiterals(FxOrderSet::from_iter([
+                            literal,
+                        ])));
+                }
+            }
+            // Same for bytes literals as for string literals, above.
+            Type::BytesLiteral(literal) => {
+                let mut found = false;
+                for element in &mut self.elements {
+                    match element {
+                        UnionElement::BytesLiterals(literals) => {
+                            literals.insert(literal);
+                            found = true;
+                            break;
+                        }
+                        UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
+                            return self;
+                        }
+                        _ => {}
+                    }
+                }
+                if !found {
+                    self.elements
+                        .push(UnionElement::BytesLiterals(FxOrderSet::from_iter([
+                            literal,
+                        ])));
+                }
+            }
+            // And same for int literals as well.
+            Type::IntLiteral(literal) => {
+                let mut found = false;
+                for element in &mut self.elements {
+                    match element {
+                        UnionElement::IntLiterals(literals) => {
+                            literals.insert(literal);
+                            found = true;
+                            break;
+                        }
+                        UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
+                            return self;
+                        }
+                        _ => {}
+                    }
+                }
+                if !found {
+                    self.elements
+                        .push(UnionElement::IntLiterals(FxOrderSet::from_iter([literal])));
+                }
+            }
             // Adding `object` to a union results in `object`.
             ty if ty.is_object(self.db) => {
                 return self.collapse_to_object();
@@ -81,8 +173,27 @@ impl<'db> UnionBuilder<'db> {
                 let mut to_remove = SmallVec::<[usize; 2]>::new();
                 let ty_negated = ty.negate(self.db);

-                for (index, element) in self.elements.iter().enumerate() {
-                    if Some(*element) == bool_pair {
+                for (index, element) in self
+                    .elements
+                    .iter()
+                    .map(|element| {
+                        // For literals, the first element in the set can stand in for all the rest,
+                        // since they all have the same super-types. SAFETY: a `UnionElement` of
+                        // literal kind must always have at least one element in it.
+                        match element {
+                            UnionElement::IntLiterals(literals) => Type::IntLiteral(literals[0]),
+                            UnionElement::StringLiterals(literals) => {
+                                Type::StringLiteral(literals[0])
+                            }
+                            UnionElement::BytesLiterals(literals) => {
+                                Type::BytesLiteral(literals[0])
+                            }
+                            UnionElement::Type(ty) => *ty,
+                        }
+                    })
+                    .enumerate()
+                {
+                    if Some(element) == bool_pair {
                         to_add = KnownClass::Bool.to_instance(self.db);
                         to_remove.push(index);
                         // The type we are adding is a BooleanLiteral, which doesn't have any
@@ -92,14 +203,14 @@ impl<'db> UnionBuilder<'db> {
                         break;
                     }

-                    if ty.is_same_gradual_form(*element)
-                        || ty.is_subtype_of(self.db, *element)
+                    if ty.is_same_gradual_form(element)
+                        || ty.is_subtype_of(self.db, element)
                         || element.is_object(self.db)
                     {
                         return self;
                     } else if element.is_subtype_of(self.db, ty) {
                         to_remove.push(index);
-                    } else if ty_negated.is_subtype_of(self.db, *element) {
+                    } else if ty_negated.is_subtype_of(self.db, element) {
                         // We add `ty` to the union. We just checked that `~ty` is a subtype of an existing `element`.
                         // This also means that `~ty | ty` is a subtype of `element | ty`, because both elements in the
                         // first union are subtypes of the corresponding elements in the second union. But `~ty | ty` is
@@ -111,13 +222,13 @@ impl<'db> UnionBuilder<'db> {
                     }
                 }
                 if let Some((&first, rest)) = to_remove.split_first() {
-                    self.elements[first] = to_add;
+                    self.elements[first] = UnionElement::Type(to_add);
                     // We iterate in descending order to keep remaining indices valid after `swap_remove`.
                     for &index in rest.iter().rev() {
                         self.elements.swap_remove(index);
                     }
                 } else {
-                    self.elements.push(to_add);
+                    self.elements.push(UnionElement::Type(to_add));
                 }
             }
         }
} }
@@ -125,10 +236,25 @@ impl<'db> UnionBuilder<'db> {
     }

     pub(crate) fn build(self) -> Type<'db> {
-        match self.elements.len() {
+        let mut types = vec![];
+        for element in self.elements {
+            match element {
+                UnionElement::IntLiterals(literals) => {
+                    types.extend(literals.into_iter().map(Type::IntLiteral));
+                }
+                UnionElement::StringLiterals(literals) => {
+                    types.extend(literals.into_iter().map(Type::StringLiteral));
+                }
+                UnionElement::BytesLiterals(literals) => {
+                    types.extend(literals.into_iter().map(Type::BytesLiteral));
+                }
+                UnionElement::Type(ty) => types.push(ty),
+            }
+        }
+        match types.len() {
             0 => Type::Never,
-            1 => self.elements[0],
-            _ => Type::Union(UnionType::new(self.db, self.elements.into_boxed_slice())),
+            1 => types[0],
+            _ => Type::Union(UnionType::new(self.db, types.into_boxed_slice())),
         }
     }
 }
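
The diff keeps the comment "We iterate in descending order to keep remaining indices valid after `swap_remove`" without spelling out why. The sketch below isolates that step in a standalone helper. The name `replace_and_remove` and its signature are hypothetical, and it assumes `to_remove` is sorted ascending with no duplicates, which is how the indices are collected in the loop above.

```rust
/// Overwrite the element at the first index in `to_remove` with `to_add`
/// and drop the elements at the remaining indices. If `to_remove` is
/// empty, `to_add` is appended instead.
///
/// Assumption: `to_remove` is sorted ascending and has no duplicates.
pub fn replace_and_remove<T>(elements: &mut Vec<T>, to_remove: &[usize], to_add: T) {
    if let Some((&first, rest)) = to_remove.split_first() {
        elements[first] = to_add;
        // `swap_remove(i)` moves the *last* element into slot `i`. Going from
        // the highest index down, the moved element always comes from a
        // position at or above `i`, so the smaller indices still to be
        // processed keep pointing at the elements they were recorded for.
        for &index in rest.iter().rev() {
            elements.swap_remove(index);
        }
    } else {
        elements.push(to_add);
    }
}
```

With `["a", "b", "c", "d", "e"]`, `to_remove = [1, 3, 4]` and `to_add = "X"`, the result is `["a", "X", "c"]`. Processing `rest` in ascending order instead would first shrink the vector to four elements and then call `swap_remove(4)`, which panics because the index is out of bounds.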