[ty] Optimize IntersectionType for the common case of a single negated element (#22344)

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
Alex Waygood
2026-01-05 13:41:50 +00:00
committed by GitHub
parent 24dd149e03
commit f3dea6e5c9
2 changed files with 252 additions and 46 deletions

View File

@@ -12011,9 +12011,235 @@ pub struct IntersectionType<'db> {
/// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them
/// directly in intersections rather than as a separate type.
#[returns(ref)]
negative: FxOrderSet<Type<'db>>,
negative: NegativeIntersectionElements<'db>,
}
/// The negative elements of an intersection type.
///
/// To avoid unnecessary allocations for the common case of a single negative element,
/// we use this enum to represent the negative elements of an intersection type.
///
/// It should otherwise have identical behavior to `FxOrderSet<Type<'db>>`.
///
/// Note that we do not try to maintain the invariant that length-0 collections
/// are always represented using `Self::Empty`, and that length-1 collections
/// are always represented using `Self::Single`: `Self::Multiple` is permitted
/// to have 0-1 elements in its wrapped data, and this could happen if you called
/// `Self::swap_remove` or `Self::swap_remove_index` on an instance that is
/// already the `Self::Multiple` variant. Maintaining the invariant that
/// 0-length or 1-length collections are always represented using `Self::Empty`
/// and `Self::Single` would add overhead to methods like `Self::swap_remove`,
/// and would have little value. At the point when you're calling that method, a
/// heap allocation has already taken place.
#[derive(Debug, Clone, get_size2::GetSize, salsa::Update, Default)]
pub enum NegativeIntersectionElements<'db> {
/// A collection with no elements. This is the `Default` value.
#[default]
Empty,
/// A collection with exactly one element, held directly in the enum.
Single(Type<'db>),
/// A collection backed by an ordered set. Created when a second, distinct element
/// is inserted into a `Single`; it may later shrink to 0-1 elements (see above).
Multiple(FxOrderSet<Type<'db>>),
}
impl<'db> NegativeIntersectionElements<'db> {
pub(crate) fn iter(&self) -> NegativeIntersectionElementsIterator<'_, 'db> {
match self {
Self::Empty => NegativeIntersectionElementsIterator::EmptyOrOne(None),
Self::Single(ty) => NegativeIntersectionElementsIterator::EmptyOrOne(Some(ty)),
Self::Multiple(set) => NegativeIntersectionElementsIterator::Multiple(set.iter()),
}
}
pub(crate) fn len(&self) -> usize {
match self {
Self::Empty => 0,
Self::Single(_) => 1,
Self::Multiple(set) => set.len(),
}
}
pub(crate) fn contains(&self, ty: &Type<'db>) -> bool {
match self {
Self::Empty => false,
Self::Single(existing) => existing == ty,
Self::Multiple(set) => set.contains(ty),
}
}
pub(crate) fn is_empty(&self) -> bool {
// See struct-level comment: we don't try to maintain the invariant that empty
// collections are representend as `Self::Empty`
self.len() == 0
}
/// Insert the type into the collection.
///
/// Returns `true` if the elements was newly added.
/// Returns `false` if the element was already present in the collection.
pub(crate) fn insert(&mut self, ty: Type<'db>) -> bool {
match self {
Self::Empty => {
*self = Self::Single(ty);
true
}
Self::Single(existing) => {
if ty != *existing {
*self = Self::Multiple(FxOrderSet::from_iter([*existing, ty]));
true
} else {
false
}
}
Self::Multiple(set) => set.insert(ty),
}
}
/// Shrink the capacity of the collection as much as possible.
pub(crate) fn shrink_to_fit(&mut self) {
match self {
Self::Empty | Self::Single(_) => {}
Self::Multiple(set) => set.shrink_to_fit(),
}
}
/// Sort the collection's types in place using the comparison function `cmp`.
pub(crate) fn sort_unstable_by(
&mut self,
compare: impl FnMut(&Type<'db>, &Type<'db>) -> std::cmp::Ordering,
) {
match self {
Self::Empty | Self::Single(_) => {}
Self::Multiple(set) => {
set.sort_unstable_by(compare);
}
}
}
/// Remove `ty` from the collection.
///
/// Returns `true` if `ty` was previously in the collection and has now been removed.
/// Returns `false` if `ty` was never present in the collection.
///
/// If `ty` was previously present in the collection,
/// the last element in the collection is popped off the end of the collection
/// and placed at the index where `ty` was previously, allowing this method to complete
/// in O(1) time (average).
pub(crate) fn swap_remove(&mut self, ty: &Type<'db>) -> bool {
match self {
Self::Empty => false,
Self::Single(existing) => {
if existing == ty {
*self = Self::Empty;
true
} else {
false
}
}
// See struct-level comment: we don't try to maintain the invariant that collections
// with size 0 or 1 are represented as `Empty` or `Single`.
Self::Multiple(set) => set.swap_remove(ty),
}
}
/// Remove the element at `index` from the collection.
///
/// The element is removed by swapping it with the last element
/// of the collection and popping it off, allowing this method to complete
/// in O(1) time (average).
pub(crate) fn swap_remove_index(&mut self, index: usize) -> Option<Type<'db>> {
match self {
Self::Empty => None,
Self::Single(existing) => {
if index == 0 {
let ty = *existing;
*self = Self::Empty;
Some(ty)
} else {
None
}
}
// See struct-level comment: we don't try to maintain the invariant that collections
// with size 0 or 1 are represented as `Empty` or `Single`.
Self::Multiple(set) => set.swap_remove_index(index),
}
}
/// Apply a transformation to all elements in this collection,
/// and return a new collection of the transformed elements.
fn map(&self, map_fn: impl Fn(&Type<'db>) -> Type<'db>) -> Self {
match self {
NegativeIntersectionElements::Empty => NegativeIntersectionElements::Empty,
NegativeIntersectionElements::Single(ty) => {
NegativeIntersectionElements::Single(map_fn(ty))
}
NegativeIntersectionElements::Multiple(set) => {
NegativeIntersectionElements::Multiple(set.iter().map(map_fn).collect())
}
}
}
/// Apply a fallible transformation to all elements in this collection,
/// and return a new collection of the transformed elements.
///
/// Returns `None` if `map_fn` fails for any element in the collection.
fn try_map(&self, map_fn: impl Fn(&Type<'db>) -> Option<Type<'db>>) -> Option<Self> {
match self {
NegativeIntersectionElements::Empty => Some(NegativeIntersectionElements::Empty),
NegativeIntersectionElements::Single(ty) => {
map_fn(ty).map(NegativeIntersectionElements::Single)
}
NegativeIntersectionElements::Multiple(set) => {
Some(NegativeIntersectionElements::Multiple(
set.iter().map(map_fn).collect::<Option<_>>()?,
))
}
}
}
}
impl<'a, 'db> IntoIterator for &'a NegativeIntersectionElements<'db> {
type Item = &'a Type<'db>;
type IntoIter = NegativeIntersectionElementsIterator<'a, 'db>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Two collections are equal when they yield equal elements in the same order,
/// regardless of which variant each one happens to be stored as.
impl PartialEq for NegativeIntersectionElements<'_> {
    fn eq(&self, other: &Self) -> bool {
        // Mirrors `OrderSet::eq`: lengths must match, then elements are compared pairwise.
        if self.len() != other.len() {
            return false;
        }
        self.iter().eq(other.iter())
    }
}
// Marker impl; the comparison itself lives in the manual `PartialEq` impl,
// which checks lengths and then compares elements pairwise.
impl Eq for NegativeIntersectionElements<'_> {}
/// Hashes the element count followed by each element in iteration order, so the
/// result depends only on the contents and not on which variant stores them.
impl std::hash::Hash for NegativeIntersectionElements<'_> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Mirrors `OrderSet::hash`: length first, then every element.
        std::hash::Hash::hash(&self.len(), state);
        self.iter()
            .for_each(|element| std::hash::Hash::hash(element, state));
    }
}
/// Borrowing iterator over the elements of a [`NegativeIntersectionElements`].
#[derive(Debug)]
pub enum NegativeIntersectionElementsIterator<'a, 'db> {
/// Iterates an `Empty` collection (`None`) or a `Single` collection (`Some`).
EmptyOrOne(Option<&'a Type<'db>>),
/// Iterates the ordered set wrapped by a `Multiple` collection.
Multiple(ordermap::set::Iter<'a, Type<'db>>),
}
impl<'a, 'db> Iterator for NegativeIntersectionElementsIterator<'a, 'db> {
type Item = &'a Type<'db>;
fn next(&mut self) -> Option<Self::Item> {
match self {
NegativeIntersectionElementsIterator::EmptyOrOne(opt) => opt.take(),
NegativeIntersectionElementsIterator::Multiple(iter) => iter.next(),
}
}
}
// `EmptyOrOne` keeps returning `None` after exhaustion because `Option::take` leaves `None` behind.
// NOTE(review): for `Multiple` this relies on `ordermap::set::Iter` itself being fused — confirm.
impl std::iter::FusedIterator for NegativeIntersectionElementsIterator<'_, '_> {}
// The Salsa heap is tracked separately.
// The impl body is intentionally empty, so only the `GetSize` trait's default methods apply here.
impl get_size2::GetSize for IntersectionType<'_> {}
@@ -12051,25 +12277,18 @@ impl<'db> IntersectionType<'db> {
}
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
fn normalized_set<'db>(
db: &'db dyn Db,
elements: &FxOrderSet<Type<'db>>,
visitor: &NormalizedVisitor<'db>,
) -> FxOrderSet<Type<'db>> {
let mut elements: FxOrderSet<Type<'db>> = elements
.iter()
.map(|ty| ty.normalized_impl(db, visitor))
.collect();
let mut positive: FxOrderSet<Type<'db>> = self
.positive(db)
.iter()
.map(|ty| ty.normalized_impl(db, visitor))
.collect();
elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
elements
}
let mut negative = self.negative(db).map(|ty| ty.normalized_impl(db, visitor));
IntersectionType::new(
db,
normalized_set(db, self.positive(db), visitor),
normalized_set(db, self.negative(db), visitor),
)
positive.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
negative.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
IntersectionType::new(db, positive, negative)
}
pub(crate) fn recursive_type_normalized_impl(
@@ -12078,42 +12297,29 @@ impl<'db> IntersectionType<'db> {
div: Type<'db>,
nested: bool,
) -> Option<Self> {
fn opt_normalized_set<'db>(
db: &'db dyn Db,
elements: &FxOrderSet<Type<'db>>,
div: Type<'db>,
nested: bool,
) -> Option<FxOrderSet<Type<'db>>> {
elements
let positive = if nested {
self.positive(db)
.iter()
.map(|ty| ty.recursive_type_normalized_impl(db, div, nested))
.collect()
}
fn normalized_set<'db>(
db: &'db dyn Db,
elements: &FxOrderSet<Type<'db>>,
div: Type<'db>,
nested: bool,
) -> FxOrderSet<Type<'db>> {
elements
.collect::<Option<FxOrderSet<Type<'db>>>>()?
} else {
self.positive(db)
.iter()
.map(|ty| {
ty.recursive_type_normalized_impl(db, div, nested)
.unwrap_or(div)
})
.collect()
}
let positive = if nested {
opt_normalized_set(db, self.positive(db), div, nested)?
} else {
normalized_set(db, self.positive(db), div, nested)
};
let negative = if nested {
opt_normalized_set(db, self.negative(db), div, nested)?
self.negative(db)
.try_map(|ty| ty.recursive_type_normalized_impl(db, div, nested))?
} else {
normalized_set(db, self.negative(db), div, nested)
self.negative(db).map(|ty| {
ty.recursive_type_normalized_impl(db, div, nested)
.unwrap_or(div)
})
};
Some(IntersectionType::new(db, positive, negative))

View File

@@ -40,8 +40,8 @@
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::type_ordering::union_or_intersection_elements_ordering;
use crate::types::{
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
TypeVarBoundOrConstraints, UnionType,
BytesLiteralType, IntersectionType, KnownClass, NegativeIntersectionElements,
StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType,
};
use crate::{Db, FxOrderSet};
use rustc_hash::FxHashSet;
@@ -945,7 +945,7 @@ impl<'db> IntersectionBuilder<'db> {
#[derive(Debug, Clone, Default)]
struct InnerIntersectionBuilder<'db> {
positive: FxOrderSet<Type<'db>>,
negative: FxOrderSet<Type<'db>>,
negative: NegativeIntersectionElements<'db>,
}
impl<'db> InnerIntersectionBuilder<'db> {