diff --git a/crates/red_knot_python_semantic/resources/mdtest/union_types.md b/crates/red_knot_python_semantic/resources/mdtest/union_types.md index 44d4d93d1d..45bbf07fac 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/union_types.md +++ b/crates/red_knot_python_semantic/resources/mdtest/union_types.md @@ -166,3 +166,46 @@ def _( reveal_type(i1) # revealed: P & Q reveal_type(i2) # revealed: P & Q ``` + +## Unions of literals with `AlwaysTruthy` and `AlwaysFalsy` + +```py +from typing import Literal +from knot_extensions import AlwaysTruthy, AlwaysFalsy + +type strings = Literal["foo", ""] +type ints = Literal[0, 1] +type bytes = Literal[b"foo", b""] + +def _( + strings_or_truthy: strings | AlwaysTruthy, + truthy_or_strings: AlwaysTruthy | strings, + strings_or_falsy: strings | AlwaysFalsy, + falsy_or_strings: AlwaysFalsy | strings, + ints_or_truthy: ints | AlwaysTruthy, + truthy_or_ints: AlwaysTruthy | ints, + ints_or_falsy: ints | AlwaysFalsy, + falsy_or_ints: AlwaysFalsy | ints, + bytes_or_truthy: bytes | AlwaysTruthy, + truthy_or_bytes: AlwaysTruthy | bytes, + bytes_or_falsy: bytes | AlwaysFalsy, + falsy_or_bytes: AlwaysFalsy | bytes, +): + reveal_type(strings_or_truthy) # revealed: Literal[""] | AlwaysTruthy + reveal_type(truthy_or_strings) # revealed: AlwaysTruthy | Literal[""] + + reveal_type(strings_or_falsy) # revealed: Literal["foo"] | AlwaysFalsy + reveal_type(falsy_or_strings) # revealed: AlwaysFalsy | Literal["foo"] + + reveal_type(ints_or_truthy) # revealed: Literal[0] | AlwaysTruthy + reveal_type(truthy_or_ints) # revealed: AlwaysTruthy | Literal[0] + + reveal_type(ints_or_falsy) # revealed: Literal[1] | AlwaysFalsy + reveal_type(falsy_or_ints) # revealed: AlwaysFalsy | Literal[1] + + reveal_type(bytes_or_truthy) # revealed: Literal[b""] | AlwaysTruthy + reveal_type(truthy_or_bytes) # revealed: AlwaysTruthy | Literal[b""] + + reveal_type(bytes_or_falsy) # revealed: Literal[b"foo"] | AlwaysFalsy + reveal_type(falsy_or_bytes) # revealed: AlwaysFalsy | Literal[b"foo"] +``` diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 411c080257..f51f59b871 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -51,6 +51,67 @@ enum UnionElement<'db> { Type(Type<'db>), } +impl<'db> UnionElement<'db> { + /// Try reducing this `UnionElement` given the presence in the same union of `other_type`. + /// + /// If this `UnionElement` is a group of literals, filter the literals present if needed and + /// return `ReduceResult::KeepIf` with a boolean value indicating whether the remaining group + /// of literals should be kept in the union + /// + /// If this `UnionElement` is some other type, return `ReduceResult::Type` so `UnionBuilder` + /// can perform more complex checks on it. + fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> { + // `AlwaysTruthy` and `AlwaysFalsy` are the only types which can be a supertype of only + // _some_ literals of the same kind, so we need to walk the full set in this case. + let needs_filter = matches!(other_type, Type::AlwaysTruthy | Type::AlwaysFalsy); + match self { + UnionElement::IntLiterals(literals) => { + ReduceResult::KeepIf(if needs_filter { + literals.retain(|literal| { + !Type::IntLiteral(*literal).is_subtype_of(db, other_type) + }); + !literals.is_empty() + } else { + // SAFETY: All `UnionElement` literal kinds must always be non-empty + !Type::IntLiteral(literals[0]).is_subtype_of(db, other_type) + }) + } + UnionElement::StringLiterals(literals) => { + ReduceResult::KeepIf(if needs_filter { + literals.retain(|literal| { + !Type::StringLiteral(*literal).is_subtype_of(db, other_type) + }); + !literals.is_empty() + } else { + // SAFETY: All `UnionElement` literal kinds must always be non-empty + !Type::StringLiteral(literals[0]).is_subtype_of(db, other_type) + }) + } + UnionElement::BytesLiterals(literals) => { + ReduceResult::KeepIf(if needs_filter { + literals.retain(|literal| { + !Type::BytesLiteral(*literal).is_subtype_of(db, other_type) + }); + !literals.is_empty() + } else { + // SAFETY: All `UnionElement` literal kinds must always be non-empty + !Type::BytesLiteral(literals[0]).is_subtype_of(db, other_type) + }) + } + UnionElement::Type(existing) => ReduceResult::Type(*existing), + } + } +} + +enum ReduceResult<'db> { + /// Reduction of this `UnionElement` is complete; keep it in the union if the nested + /// boolean is true, eliminate it from the union if false. + KeepIf(bool), + /// The given `Type` can stand-in for the entire `UnionElement` for further union + /// simplification checks. + Type(Type<'db>), +} + // TODO increase this once we extend `UnionElement` throughout all union/intersection // representations, so that we can make large unions of literals fast in all operations. const MAX_UNION_LITERALS: usize = 200; @@ -197,27 +258,17 @@ impl<'db> UnionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 2]>::new(); let ty_negated = ty.negate(self.db); - for (index, element) in self - .elements - .iter() - .map(|element| { - // For literals, the first element in the set can stand in for all the rest, - // since they all have the same super-types. SAFETY: a `UnionElement` of - // literal kind must always have at least one element in it. - match element { - UnionElement::IntLiterals(literals) => Type::IntLiteral(literals[0]), - UnionElement::StringLiterals(literals) => { - Type::StringLiteral(literals[0]) + for (index, element) in self.elements.iter_mut().enumerate() { + let element_type = match element.try_reduce(self.db, ty) { + ReduceResult::KeepIf(keep) => { + if !keep { + to_remove.push(index); } - UnionElement::BytesLiterals(literals) => { - Type::BytesLiteral(literals[0]) - } - UnionElement::Type(ty) => *ty, + continue; } - }) - .enumerate() - { - if Some(element) == bool_pair { + ReduceResult::Type(ty) => ty, + }; + if Some(element_type) == bool_pair { to_add = KnownClass::Bool.to_instance(self.db); to_remove.push(index); // The type we are adding is a BooleanLiteral, which doesn't have any @@ -227,14 +278,14 @@ impl<'db> UnionBuilder<'db> { break; } - if ty.is_same_gradual_form(element) - || ty.is_subtype_of(self.db, element) - || element.is_object(self.db) + if ty.is_same_gradual_form(element_type) + || ty.is_subtype_of(self.db, element_type) + || element_type.is_object(self.db) { return; - } else if element.is_subtype_of(self.db, ty) { + } else if element_type.is_subtype_of(self.db, ty) { to_remove.push(index); - } else if ty_negated.is_subtype_of(self.db, element) { + } else if ty_negated.is_subtype_of(self.db, element_type) { // We add `ty` to the union. We just checked that `~ty` is a subtype of an existing `element`. // This also means that `~ty | ty` is a subtype of `element | ty`, because both elements in the // first union are subtypes of the corresponding elements in the second union. But `~ty | ty` is