diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index da8e2668d8..5b7593837b 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -25,8 +25,10 @@ //! * No type in an intersection can be a supertype of any other type in the intersection (just //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. + use crate::types::{IntersectionType, Type, UnionType}; use crate::{Db, FxOrderSet}; +use ordermap::set::MutableValues; use super::builtins_symbol_ty_by_name; @@ -63,23 +65,31 @@ impl<'db> UnionBuilder<'db> { /// - TODO For enums `E` with members `X1`,...,`Xn`, replaces /// `Literal[E.X1,...,E.Xn]` with `E`. fn simplify(&mut self) { - if self - .elements - .is_superset(&[Type::BooleanLiteral(true), Type::BooleanLiteral(false)].into()) - { - let bool_ty = builtins_symbol_ty_by_name(self.db, "bool"); - self.elements.remove(&Type::BooleanLiteral(true)); - self.elements.remove(&Type::BooleanLiteral(false)); - self.elements.insert(bool_ty); + if let Some(true_index) = self.elements.get_index_of(&Type::BooleanLiteral(true)) { + if self.elements.contains(&Type::BooleanLiteral(false)) { + *self.elements.get_index_mut2(true_index).unwrap() = + builtins_symbol_ty_by_name(self.db, "bool"); + self.elements.remove(&Type::BooleanLiteral(false)); + } } } pub(crate) fn build(mut self) -> Type<'db> { - self.simplify(); match self.elements.len() { 0 => Type::Never, 1 => self.elements[0], - _ => Type::Union(UnionType::new(self.db, self.elements)), + _ => { + self.simplify(); + + match self.elements.len() { + 0 => Type::Never, + 1 => self.elements[0], + _ => { + self.elements.shrink_to_fit(); + Type::Union(UnionType::new(self.db, self.elements)) + } + } + } } } } @@ -360,7 +370,7 @@ mod tests { panic!("expected a union"); }; - assert_eq!(union.elements_vec(&db), &[t3, bool_ty]); + assert_eq!(union.elements_vec(&db), &[bool_ty, t3]); } #[test]