diff --git a/crates/ty_python_semantic/resources/mdtest/conditional/match.md b/crates/ty_python_semantic/resources/mdtest/conditional/match.md index 492ca8ef53..729483fcf5 100644 --- a/crates/ty_python_semantic/resources/mdtest/conditional/match.md +++ b/crates/ty_python_semantic/resources/mdtest/conditional/match.md @@ -386,3 +386,52 @@ def _(target: int, flag: NotBoolable): reveal_type(y) # revealed: Literal[1, 2, 3] ``` + +## Matching on enum | None without covering None + +When matching on a union of an enum and None, code after the match should still be reachable if None +is not covered by any case, even when all enum members are covered. + +```py +from enum import Enum + +class Answer(Enum): + YES = 1 + NO = 2 + +def _(answer: Answer | None): + y = 0 + match answer: + case Answer.YES: + y = 1 + case Answer.NO: + y = 2 + + # The match is not exhaustive because None is not covered, + # so y could still be 0 + reveal_type(y) # revealed: Literal[0, 1, 2] + +def _(answer: Answer | None): + match answer: + case Answer.YES: + return 1 + case Answer.NO: + return 2 + + # Code here is reachable because None is not covered + reveal_type(answer) # revealed: None + return 3 + +class Foo: ... + +def _(answer: Answer | None): + match answer: + case Answer.YES: + return + case Answer.NO: + return + + # New assignments after the match should not be `Never` + x = Foo() + reveal_type(x) # revealed: Foo +``` diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 132c5b077d..36a258f20d 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -810,13 +810,6 @@ impl<'db> IntersectionBuilder<'db> { ty: Type<'db>, seen_aliases: &mut Vec>, ) -> Self { - let contains_enum = |enum_instance| { - self.intersections - .iter() - .flat_map(|intersection| &intersection.positive) - .any(|ty| *ty == enum_instance) - }; - // See comments above in `add_positive`; this is just the negated version. match ty { Type::TypeAlias(alias) => { @@ -871,12 +864,27 @@ impl<'db> IntersectionBuilder<'db> { }, ) } - Type::EnumLiteral(enum_literal) - if contains_enum(enum_literal.enum_class_instance(self.db)) => - { - let db = self.db; - self.add_positive_impl( - UnionType::from_elements( + Type::EnumLiteral(enum_literal) => { + let enum_instance = enum_literal.enum_class_instance(self.db); + + // Partition intersections into those that contain the enum instance and those that don't. + // For intersections containing the enum, we need to expand to remaining members. + // For others, we just add the negative normally. + let (enum_intersections, other_intersections): (Vec<_>, Vec<_>) = self + .intersections + .into_iter() + .partition(|inner| inner.positive.contains(&enum_instance)); + + if enum_intersections.is_empty() { + // No inner intersection contains the enum, just add negative normally + self.intersections = other_intersections; + for inner in &mut self.intersections { + inner.add_negative(self.db, ty); + } + self + } else { + let db = self.db; + let remaining_members = UnionType::from_elements( db, enum_member_literals( db, @@ -884,9 +892,32 @@ impl<'db> IntersectionBuilder<'db> { Some(enum_literal.name(db)), ) .expect("Calling `enum_member_literals` on an enum class"), - ), - seen_aliases, - ) + ); + + // For enum-containing intersections, add the remaining members as positive + let mut enum_builder = IntersectionBuilder { + db, + order_elements: self.order_elements, + intersections: enum_intersections, + } + .add_positive_impl(remaining_members, seen_aliases); + + // For non-enum intersections, just add the negative normally + let mut other_builder = IntersectionBuilder { + db, + order_elements: self.order_elements, + intersections: other_intersections, + }; + for inner in &mut other_builder.intersections { + inner.add_negative(db, ty); + } + + // Combine the results + enum_builder + .intersections + .extend(other_builder.intersections); + enum_builder + } } _ => { for inner in &mut self.intersections {