[ty] Optimize union building for unions with many enum-literal members (#22363)

This commit is contained in:
Alex Waygood
2026-01-08 10:50:04 +00:00
committed by GitHub
parent 7319c37f4e
commit eeac2bd3ee
2 changed files with 637 additions and 37 deletions

File diff suppressed because one or more lines are too long

View File

@@ -39,28 +39,42 @@
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::type_ordering::union_or_intersection_elements_ordering;
use crate::types::{
BytesLiteralType, IntersectionType, KnownClass, NegativeIntersectionElements,
StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType,
BytesLiteralType, ClassLiteral, EnumLiteralType, IntersectionType, KnownClass,
NegativeIntersectionElements, StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType,
};
use crate::{Db, FxOrderSet};
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LiteralKind {
enum LiteralKind<'db> {
Int,
String,
Bytes,
Enum { enum_class: ClassLiteral<'db> },
}
impl<'db> Type<'db> {
/// Return `true` if this type can be a supertype of some literals of `kind` and not others.
fn splits_literals(self, db: &'db dyn Db, kind: LiteralKind) -> bool {
match (self, kind) {
// Note that as of 2026-01-04, `AlwaysFalsy` and `AlwaysTruthy` never split
// enum literals, but that could change in the future. `Literal[Foo.X]` could
// plausibly be understood by ty as a subtype of `AlwaysFalsy` in the following
// snippet, because `Foo` is an IntEnum that does not override `__bool__` and
// `Foo.X` has a falsy value whereas `Foo.Y` does not:
//
// ```py
// class Foo(enum.IntEnum):
// X = 0
// Y = 1
// ```
(Type::AlwaysFalsy | Type::AlwaysTruthy, _) => true,
(Type::StringLiteral(_), LiteralKind::String) => true,
(Type::BytesLiteral(_), LiteralKind::Bytes) => true,
(Type::IntLiteral(_), LiteralKind::Int) => true,
(Type::EnumLiteral(enum_literal), LiteralKind::Enum { enum_class }) => {
enum_literal.enum_class(db) == enum_class
}
(Type::Intersection(intersection), _) => {
intersection
.positive(db)
@@ -85,17 +99,14 @@ enum UnionElement<'db> {
IntLiterals(FxOrderSet<i64>),
StringLiterals(FxOrderSet<StringLiteralType<'db>>),
BytesLiterals(FxOrderSet<BytesLiteralType<'db>>),
EnumLiterals {
enum_class: ClassLiteral<'db>,
literals: FxOrderSet<EnumLiteralType<'db>>,
},
Type(Type<'db>),
}
impl<'db> UnionElement<'db> {
/// Return the wrapped type when this element is a `UnionElement::Type`;
/// every other variant yields `None`.
const fn to_type_element(&self) -> Option<Type<'db>> {
    if let UnionElement::Type(ty) = self {
        Some(*ty)
    } else {
        None
    }
}
/// Try reducing this `UnionElement` given the presence in the same union of `other_type`.
fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> {
let mut other_type_negated_cache = None;
@@ -160,6 +171,20 @@ impl<'db> UnionElement<'db> {
!Type::BytesLiteral(literals[0]).is_redundant_with(db, other_type)
}
}
UnionElement::EnumLiterals {
enum_class,
literals,
} => {
let literal_kind = LiteralKind::Enum {
enum_class: *enum_class,
};
if other_type.splits_literals(db, literal_kind) {
literals.retain(|literal| should_retain_type(Type::EnumLiteral(*literal)));
!literals.is_empty()
} else {
!Type::EnumLiteral(literals[0]).is_redundant_with(db, other_type)
}
}
UnionElement::Type(existing) => return ReduceResult::Type(*existing),
};
@@ -205,19 +230,26 @@ impl RecursivelyDefined {
}
}
/// Literal-count limit used for recursively defined values.
///
/// If the value is defined recursively, widening is performed from fewer literal elements,
/// resulting in faster convergence of the fixed-point iteration.
const MAX_RECURSIVE_UNION_LITERALS: usize = 10;
/// Literal-count limit used for non-recursively defined values.
///
/// If the value is defined non-recursively, the fixed-point iteration will converge in one go,
/// so in principle we can have as many literal elements as we want,
/// but to avoid unintended huge computational loads, we limit it to 256.
const MAX_NON_RECURSIVE_UNION_LITERALS: usize = 256;
/// Literal-count limit used for enum literals of non-recursively defined values.
///
/// However, we set a much larger limit for enum literals than for other kinds of literals.
/// Huge enums are not uncommon (especially in generated code), and it's annoying
/// if reachability analysis etc. fails when analysing these enums.
const MAX_NON_RECURSIVE_UNION_ENUM_LITERALS: usize = 8192;
/// Builder that accumulates the elements of a union type.
pub(crate) struct UnionBuilder<'db> {
    /// The elements added so far. Literals are stored grouped per kind
    /// (see the `UnionElement` variants); anything else is a `UnionElement::Type`.
    elements: Vec<UnionElement<'db>>,
    /// Database handle passed to the type queries made while building.
    db: &'db dyn Db,
    // NOTE(review): presumably controls whether type aliases are unpacked
    // when they are added to the union — confirm against the `add` logic.
    unpack_aliases: bool,
    // NOTE(review): presumably controls whether the built union's elements
    // are put into a canonical order — confirm against the `build` logic.
    order_elements: bool,
    /// This is enabled when joining types in a `cycle_recovery` function.
    /// Since a cycle cannot be created within a `cycle_recovery` function,
    /// execution of `is_redundant_with` is skipped.
    cycle_recovery: bool,
    /// Whether the value being built is recursively defined. This is consulted
    /// when choosing the literal-count limit and deciding whether to widen literals.
    recursively_defined: RecursivelyDefined,
}
@@ -280,6 +312,9 @@ impl<'db> UnionBuilder<'db> {
UnionElement::BytesLiterals(_) => {
replace_with.push(KnownClass::Bytes.to_instance(self.db));
}
UnionElement::EnumLiterals { literals, .. } => {
replace_with.push(literals[0].enum_class_instance(self.db));
}
UnionElement::Type(_) => {}
}
}
@@ -327,6 +362,7 @@ impl<'db> UnionBuilder<'db> {
UnionElement::IntLiterals(literals) => acc + literals.len(),
UnionElement::StringLiterals(literals) => acc + literals.len(),
UnionElement::BytesLiterals(literals) => acc + literals.len(),
UnionElement::EnumLiterals { literals, .. } => acc + literals.len(),
UnionElement::Type(_) => acc,
});
if should_widen(literals, self.recursively_defined) {
@@ -492,32 +528,78 @@ impl<'db> UnionBuilder<'db> {
let metadata =
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
let enum_members_in_union = self
.elements
.iter()
.filter_map(UnionElement::to_type_element)
.filter_map(Type::as_enum_literal)
.map(|literal| literal.name(self.db))
.chain(std::iter::once(enum_member_to_add.name(self.db)))
.collect::<FxHashSet<_>>();
let all_members_are_in_union = metadata
.members
.keys()
.all(|name| enum_members_in_union.contains(name));
if all_members_are_in_union {
if metadata.members.len() == 1 {
self.add_in_place_impl(
enum_member_to_add.enum_class_instance(self.db),
seen_aliases,
);
} else if !self
.elements
.iter()
.filter_map(UnionElement::to_type_element)
.any(|ty| Type::EnumLiteral(enum_member_to_add).is_subtype_of(self.db, ty))
{
self.push_type(Type::EnumLiteral(enum_member_to_add), seen_aliases);
return;
}
let mut found = None;
let mut to_remove = None;
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::EnumLiterals {
enum_class: existing_enum_class,
literals,
} => {
if *existing_enum_class != enum_class {
continue;
}
// See the doc-comment above `MAX_NON_RECURSIVE_UNION_ENUM_LITERALS`
// for why we avoid using the `should_widen` closure here.
let enum_literals_limit =
if self.recursively_defined.is_yes() && cycle_recovery {
MAX_RECURSIVE_UNION_LITERALS
} else {
MAX_NON_RECURSIVE_UNION_ENUM_LITERALS
};
if literals.len() >= enum_literals_limit {
let replace_with = literals[0].enum_class_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
return;
}
found = Some(literals);
continue;
}
UnionElement::Type(existing) => {
if ty.is_redundant_with(self.db, *existing) {
return;
}
// e.g. `existing` could be `Literal[Foo.X] & Any`,
// and `ty` could be `Literal[Foo.X]`
if existing.is_redundant_with(self.db, ty) {
to_remove = Some(index);
continue;
}
if ty_negated().is_subtype_of(self.db, *existing) {
// The type that includes both this new element, and its negation
// (or a supertype of its negation), must be simply `object`.
self.collapse_to_object();
return;
}
}
_ => {}
}
}
if let Some(found) = found {
let newly_added = found.insert(enum_member_to_add);
if newly_added && found.len() == metadata.members.len() {
self.add_in_place_impl(
enum_member_to_add.enum_class_instance(self.db),
seen_aliases,
);
return;
}
} else {
self.elements.push(UnionElement::EnumLiterals {
enum_class,
literals: FxOrderSet::from_iter([enum_member_to_add]),
});
}
if let Some(index) = to_remove {
self.elements.swap_remove(index);
}
}
// Adding `object` to a union results in `object`.
@@ -636,6 +718,9 @@ impl<'db> UnionBuilder<'db> {
UnionElement::BytesLiterals(literals) => {
types.extend(literals.into_iter().map(Type::BytesLiteral));
}
UnionElement::EnumLiterals { literals, .. } => {
types.extend(literals.into_iter().map(Type::EnumLiteral));
}
UnionElement::Type(ty) => types.push(ty),
}
}