diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md index 748b94eda3..b576ce318d 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/integers.md @@ -19,7 +19,8 @@ reveal_type(1 <= "" and 0 < 1) # revealed: @Todo | Literal[True] ```py # TODO: implement lookup of `__eq__` on typeshed `int` stub. -def int_instance() -> int: ... +def int_instance() -> int: + return 42 reveal_type(1 == int_instance()) # revealed: @Todo reveal_type(9 < int_instance()) # revealed: bool diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md index d14bd93680..c80ee4d601 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md @@ -59,7 +59,8 @@ reveal_type(c >= d) # revealed: Literal[True] ```py def bool_instance() -> bool: ... -def int_instance() -> int: ... +def int_instance() -> int: + return 42 a = (bool_instance(),) b = (int_instance(),) @@ -159,7 +160,8 @@ reveal_type(a >= a) # revealed: @Todo "Membership Test Comparisons" refers to the operators `in` and `not in`. ```py -def int_instance() -> int: ... +def int_instance() -> int: + return 42 a = (1, 2) b = ((3, 4), (1, 2)) diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_elif_else.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_elif_else.md new file mode 100644 index 0000000000..7ac16f4b8d --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_elif_else.md @@ -0,0 +1,57 @@ +# Narrowing for conditionals with elif and else + +## Positive contributions become negative in elif-else blocks + +```py +def int_instance() -> int: + return 42 + +x = int_instance() + +if x == 1: + # cannot narrow; could be a subclass of `int` + reveal_type(x) # revealed: int +elif x == 2: + reveal_type(x) # revealed: int & ~Literal[1] +elif x != 3: + reveal_type(x) # revealed: int & ~Literal[1] & ~Literal[2] & ~Literal[3] +``` + +## Positive contributions become negative in elif-else blocks, with simplification + +```py +def bool_instance() -> bool: + return True + +x = 1 if bool_instance() else 2 if bool_instance() else 3 + +if x == 1: + # TODO should be Literal[1] + reveal_type(x) # revealed: Literal[1, 2, 3] +elif x == 2: + # TODO should be Literal[2] + reveal_type(x) # revealed: Literal[2, 3] +else: + reveal_type(x) # revealed: Literal[3] +``` + +## Multiple negative contributions using elif, with simplification + +```py +def bool_instance() -> bool: + return True + +x = 1 if bool_instance() else 2 if bool_instance() else 3 + +if x != 1: + reveal_type(x) # revealed: Literal[2, 3] +elif x != 2: + # TODO should be `Literal[1]` + reveal_type(x) # revealed: Literal[1, 3] +elif x == 3: + # TODO should be Never + reveal_type(x) # revealed: Literal[1, 2, 3] +else: + # TODO should be Never + reveal_type(x) # revealed: Literal[1, 2] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md index 8c042c75d0..b9cd22897d 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md @@ -11,6 +11,8 @@ x = None if flag else 1 if x is None: reveal_type(x) # revealed: None +else: + reveal_type(x) # revealed: Literal[1] reveal_type(x) # revealed: None | Literal[1] ``` @@ -30,6 +32,8 @@ y = x if flag else None if y is x: reveal_type(y) # revealed: A +else: + reveal_type(y) # revealed: A | None reveal_type(y) # revealed: A | None ``` @@ -50,4 +54,26 @@ reveal_type(y) # revealed: bool if y is x is False: # Interpreted as `(y is x) and (x is False)` reveal_type(x) # revealed: Literal[False] reveal_type(y) # revealed: bool +else: + # The negation of the clause above is (y is not x) or (x is not False) + # So we can't narrow the type of x or y here, because each arm of the `or` could be true + reveal_type(x) # revealed: bool + reveal_type(y) # revealed: bool +``` + +## `is` in elif clause + +```py +def bool_instance() -> bool: + return True + +x = None if bool_instance() else (1 if bool_instance() else True) + +reveal_type(x) # revealed: None | Literal[1] | Literal[True] +if x is None: + reveal_type(x) # revealed: None +elif x is True: + reveal_type(x) # revealed: Literal[True] +else: + reveal_type(x) # revealed: Literal[1] ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md index f7032d02af..f23bae8e1e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md @@ -13,6 +13,8 @@ x = None if flag else 1 if x is not None: reveal_type(x) # revealed: Literal[1] +else: + reveal_type(x) # revealed: None reveal_type(x) # revealed: None | Literal[1] ``` @@ -29,6 +31,8 @@ reveal_type(x) # revealed: bool if x is not False: reveal_type(x) # revealed: Literal[True] +else: + reveal_type(x) # revealed: Literal[False] ``` ## `is not` for non-singleton types @@ -43,6 +47,27 @@ y = 345 if x is not y: reveal_type(x) # revealed: Literal[345] +else: + reveal_type(x) # revealed: Literal[345] +``` + +## `is not` for other types + +```py +def bool_instance() -> bool: + return True + +class A: ... + +x = A() +y = x if bool_instance() else None + +if y is not x: + reveal_type(y) # revealed: A | None +else: + reveal_type(y) # revealed: A + +reveal_type(y) # revealed: A | None ``` ## `is not` in chained comparisons @@ -63,4 +88,10 @@ reveal_type(y) # revealed: bool if y is not x is not False: # Interpreted as `(y is not x) and (x is not False)` reveal_type(x) # revealed: Literal[True] reveal_type(y) # revealed: bool +else: + # The negation of the clause above is (y is x) or (x is False) + # So we can't narrow the type of x or y here, because each arm of the `or` could be true + + reveal_type(x) # revealed: bool + reveal_type(y) # revealed: bool ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_nested.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_nested.md index 46492f8cca..cc0f79165e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_nested.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_nested.md @@ -3,7 +3,8 @@ ## Multiple negative contributions ```py -def int_instance() -> int: ... +def int_instance() -> int: + return 42 x = int_instance() @@ -27,3 +28,29 @@ if x != 1: if x != 2: reveal_type(x) # revealed: Literal[3] ``` + +## elif-else blocks + +```py +def bool_instance() -> bool: + return True + +x = 1 if bool_instance() else 2 if bool_instance() else 3 + +if x != 1: + reveal_type(x) # revealed: Literal[2, 3] + if x == 2: + # TODO should be `Literal[2]` + reveal_type(x) # revealed: Literal[2, 3] + elif x == 3: + reveal_type(x) # revealed: Literal[3] + else: + reveal_type(x) # revealed: Never + +elif x != 2: + # TODO should be Literal[1] + reveal_type(x) # revealed: Literal[1, 3] +else: + # TODO should be Never + reveal_type(x) # revealed: Literal[1, 2, 3] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_not_eq.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_not_eq.md index d0af94e9ea..3ad8ebcb68 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_not_eq.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_not_eq.md @@ -11,6 +11,9 @@ x = None if flag else 1 if x != None: reveal_type(x) # revealed: Literal[1] +else: + # TODO should be None + reveal_type(x) # revealed: None | Literal[1] ``` ## `!=` for other singleton types @@ -24,6 +27,9 @@ x = True if flag else False if x != False: reveal_type(x) # revealed: Literal[True] +else: + # TODO should be Literal[False] + reveal_type(x) # revealed: bool ``` ## `x != y` where `y` is of literal type @@ -54,6 +60,25 @@ C = A if flag else B if C != A: reveal_type(C) # revealed: Literal[B] +else: + # TODO should be Literal[A] + reveal_type(C) # revealed: Literal[A, B] +``` + +## `x != y` where `y` has multiple single-valued options + +```py +def bool_instance() -> bool: + return True + +x = 1 if bool_instance() else 2 +y = 2 if bool_instance() else 3 + +if x != y: + reveal_type(x) # revealed: Literal[1, 2] +else: + # TODO should be Literal[2] + reveal_type(x) # revealed: Literal[1, 2] ``` ## `!=` for non-single-valued types @@ -74,3 +99,21 @@ y = int_instance() if x != y: reveal_type(x) # revealed: int | None ``` + +## Mix of single-valued and non-single-valued types + +```py +def int_instance() -> int: + return 42 + +def bool_instance() -> bool: + return True + +x = 1 if bool_instance() else 2 +y = 2 if bool_instance() else int_instance() + +if x != y: + reveal_type(x) # revealed: Literal[1, 2] +else: + reveal_type(x) # revealed: Literal[1, 2] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md index 5b5d984107..4b07d5648a 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md @@ -40,6 +40,8 @@ x = 1 if flag else "a" if isinstance(x, (int, str)): reveal_type(x) # revealed: Literal[1] | Literal["a"] +else: + reveal_type(x) # revealed: Never if isinstance(x, (int, bytes)): reveal_type(x) # revealed: Literal[1] @@ -51,6 +53,8 @@ if isinstance(x, (bytes, str)): # one of the possibilities: if isinstance(x, (int, object)): reveal_type(x) # revealed: Literal[1] | Literal["a"] +else: + reveal_type(x) # revealed: Never y = 1 if flag1 else "a" if flag2 else b"b" if isinstance(y, (int, str)): @@ -75,6 +79,8 @@ x = 1 if flag else "a" if isinstance(x, (bool, (bytes, int))): reveal_type(x) # revealed: Literal[1] +else: + reveal_type(x) # revealed: Literal["a"] ``` ## Class types @@ -82,6 +88,7 @@ if isinstance(x, (bool, (bytes, int))): ```py class A: ... class B: ... +class C: ... def get_object() -> object: ... @@ -91,6 +98,16 @@ if isinstance(x, A): reveal_type(x) # revealed: A if isinstance(x, B): reveal_type(x) # revealed: A & B + else: + reveal_type(x) # revealed: A & ~B + +if isinstance(x, (A, B)): + reveal_type(x) # revealed: A | B +elif isinstance(x, (A, C)): + reveal_type(x) # revealed: C & ~A & ~B +else: + # TODO: Should be simplified to ~A & ~B & ~C + reveal_type(x) # revealed: object & ~A & ~B & ~C ``` ## No narrowing for instances of `builtins.type` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md index bd807db36c..15330e4006 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md @@ -26,7 +26,8 @@ reveal_type(y) # revealed: Unknown ## Function return ```py -def int_instance() -> int: ... +def int_instance() -> int: + return 42 a = b"abcde"[int_instance()] # TODO: Support overloads... Should be `bytes` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md index 586eb4dfda..6987a95a70 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md @@ -23,7 +23,8 @@ reveal_type(b) # revealed: Unknown ## Function return ```py -def int_instance() -> int: ... +def int_instance() -> int: + return 42 a = "abcde"[int_instance()] # TODO: Support overloads... Should be `str` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index abbdd3f0c4..523682fcd3 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -27,7 +27,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder}; use crate::semantic_index::SemanticIndex; use crate::Db; -use super::constraint::{Constraint, PatternConstraint}; +use super::constraint::{Constraint, ConstraintNode, PatternConstraint}; use super::definition::{ AssignmentKind, DefinitionCategory, ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef, @@ -243,12 +243,23 @@ impl<'db> SemanticIndexBuilder<'db> { definition } - fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> { + fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Constraint<'db> { let expression = self.add_standalone_expression(constraint_node); - self.current_use_def_map_mut() - .record_constraint(Constraint::Expression(expression)); + let constraint = Constraint { + node: ConstraintNode::Expression(expression), + is_positive: true, + }; + self.current_use_def_map_mut().record_constraint(constraint); - expression + constraint + } + + fn add_negated_constraint(&mut self, constraint: Constraint<'db>) { + self.current_use_def_map_mut() + .record_constraint(Constraint { + node: constraint.node, + is_positive: false, + }); } fn push_assignment(&mut self, assignment: CurrentAssignment<'db>) { @@ -285,7 +296,10 @@ impl<'db> SemanticIndexBuilder<'db> { countme::Count::default(), ); self.current_use_def_map_mut() - .record_constraint(Constraint::Pattern(pattern_constraint)); + .record_constraint(Constraint { + node: ConstraintNode::Pattern(pattern_constraint), + is_positive: true, + }); pattern_constraint } @@ -639,7 +653,8 @@ where ast::Stmt::If(node) => { self.visit_expr(&node.test); let pre_if = self.flow_snapshot(); - self.add_expression_constraint(&node.test); + let constraint = self.add_expression_constraint(&node.test); + let mut constraints = vec![constraint]; self.visit_body(&node.body); let mut post_clauses: Vec = vec![]; for clause in &node.elif_else_clauses { @@ -649,7 +664,14 @@ where // we can only take an elif/else branch if none of the previous ones were // taken, so the block entry state is always `pre_if` self.flow_restore(pre_if.clone()); - self.visit_elif_else_clause(clause); + for constraint in &constraints { + self.add_negated_constraint(*constraint); + } + if let Some(elif_test) = &clause.test { + self.visit_expr(elif_test); + constraints.push(self.add_expression_constraint(elif_test)); + } + self.visit_body(&clause.body); } for post_clause_state in post_clauses { self.flow_merge(post_clause_state); diff --git a/crates/red_knot_python_semantic/src/semantic_index/constraint.rs b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs index 9659d5f82f..44b542f0e9 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/constraint.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs @@ -7,7 +7,13 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{FileScopeId, ScopeId}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub(crate) enum Constraint<'db> { +pub(crate) struct Constraint<'db> { + pub(crate) node: ConstraintNode<'db>, + pub(crate) is_positive: bool, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum ConstraintNode<'db> { Expression(Expression<'db>), Pattern(PatternConstraint<'db>), } diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index d445b694be..c932492f21 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -332,6 +332,11 @@ impl<'db> Type<'db> { .expect("Expected a Type::ModuleLiteral variant") } + #[must_use] + pub fn negate(&self, db: &'db dyn Db) -> Type<'db> { + IntersectionBuilder::new(db).add_negative(*self).build() + } + pub const fn into_union_type(self) -> Option> { match self { Type::Union(union_type) => Some(union_type), diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 7a17fc1085..0a6e772831 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -173,14 +173,10 @@ impl<'db> IntersectionBuilder<'db> { pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self { // See comments above in `add_positive`; this is just the negated version. if let Type::Union(union) = ty { - union - .elements(self.db) - .iter() - .map(|elem| self.clone().add_negative(*elem)) - .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); - builder - }) + for elem in union.elements(self.db) { + self = self.add_negative(*elem); + } + self } else { for inner in &mut self.intersections { inner.add_negative(self.db, ty); @@ -667,6 +663,27 @@ mod tests { assert_eq!(ty, Type::IntLiteral(1)); } + #[test] + fn build_negative_union_de_morgan() { + let db = setup_db(); + + let union = UnionBuilder::new(&db) + .add(Type::IntLiteral(1)) + .add(Type::IntLiteral(2)) + .build(); + assert_eq!(union.display(&db).to_string(), "Literal[1, 2]"); + + let ty = IntersectionBuilder::new(&db).add_negative(union).build(); + + let expected = IntersectionBuilder::new(&db) + .add_negative(Type::IntLiteral(1)) + .add_negative(Type::IntLiteral(2)) + .build(); + + assert_eq!(ty.display(&db).to_string(), "~Literal[1] & ~Literal[2]"); + assert_eq!(ty, expected); + } + #[test] fn build_intersection_simplify_positive_type_and_positive_subtype() { let db = setup_db(); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 7f932827f1..f3c6b10691 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2415,7 +2415,6 @@ impl<'db> TypeInferenceBuilder<'db> { } else { None }; - let ty = bindings_ty(self.db, definitions, unbound_ty); if ty.is_unbound() { diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 93d9cb43bf..58a314523a 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -1,5 +1,5 @@ use crate::semantic_index::ast_ids::HasScopedAstId; -use crate::semantic_index::constraint::{Constraint, PatternConstraint}; +use crate::semantic_index::constraint::{Constraint, ConstraintNode, PatternConstraint}; use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; @@ -34,13 +34,19 @@ pub(crate) fn narrowing_constraint<'db>( constraint: Constraint<'db>, definition: Definition<'db>, ) -> Option> { - match constraint { - Constraint::Expression(expression) => { - all_narrowing_constraints_for_expression(db, expression) - .get(&definition.symbol(db)) - .copied() + match constraint.node { + ConstraintNode::Expression(expression) => { + if constraint.is_positive { + all_narrowing_constraints_for_expression(db, expression) + .get(&definition.symbol(db)) + .copied() + } else { + all_negative_narrowing_constraints_for_expression(db, expression) + .get(&definition.symbol(db)) + .copied() + } } - Constraint::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern) + ConstraintNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern) .get(&definition.symbol(db)) .copied(), } @@ -51,7 +57,7 @@ fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternConstraint<'db>, ) -> NarrowingConstraints<'db> { - NarrowingConstraintsBuilder::new(db, Constraint::Pattern(pattern)).finish() + NarrowingConstraintsBuilder::new(db, ConstraintNode::Pattern(pattern), true).finish() } #[salsa::tracked(return_ref)] @@ -59,7 +65,15 @@ fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, ) -> NarrowingConstraints<'db> { - NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression)).finish() + NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), true).finish() +} + +#[salsa::tracked(return_ref)] +fn all_negative_narrowing_constraints_for_expression<'db>( + db: &'db dyn Db, + expression: Expression<'db>, +) -> NarrowingConstraints<'db> { + NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), false).finish() } /// Generate a constraint from the *type* of the second argument of an `isinstance` call. @@ -88,36 +102,39 @@ type NarrowingConstraints<'db> = FxHashMap>; struct NarrowingConstraintsBuilder<'db> { db: &'db dyn Db, - constraint: Constraint<'db>, + constraint: ConstraintNode<'db>, + is_positive: bool, constraints: NarrowingConstraints<'db>, } impl<'db> NarrowingConstraintsBuilder<'db> { - fn new(db: &'db dyn Db, constraint: Constraint<'db>) -> Self { + fn new(db: &'db dyn Db, constraint: ConstraintNode<'db>, is_positive: bool) -> Self { Self { db, constraint, + is_positive, constraints: NarrowingConstraints::default(), } } fn finish(mut self) -> NarrowingConstraints<'db> { match self.constraint { - Constraint::Expression(expression) => self.evaluate_expression_constraint(expression), - Constraint::Pattern(pattern) => self.evaluate_pattern_constraint(pattern), + ConstraintNode::Expression(expression) => { + self.evaluate_expression_constraint(expression, self.is_positive); + } + ConstraintNode::Pattern(pattern) => self.evaluate_pattern_constraint(pattern), } - self.constraints.shrink_to_fit(); self.constraints } - fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) { + fn evaluate_expression_constraint(&mut self, expression: Expression<'db>, is_positive: bool) { match expression.node_ref(self.db).node() { ast::Expr::Compare(expr_compare) => { - self.add_expr_compare(expr_compare, expression); + self.add_expr_compare(expr_compare, expression, is_positive); } ast::Expr::Call(expr_call) => { - self.add_expr_call(expr_call, expression); + self.add_expr_call(expr_call, expression, is_positive); } _ => {} // TODO other test expression kinds } @@ -160,12 +177,17 @@ impl<'db> NarrowingConstraintsBuilder<'db> { fn scope(&self) -> ScopeId<'db> { match self.constraint { - Constraint::Expression(expression) => expression.scope(self.db), - Constraint::Pattern(pattern) => pattern.scope(self.db), + ConstraintNode::Expression(expression) => expression.scope(self.db), + ConstraintNode::Pattern(pattern) => pattern.scope(self.db), } } - fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>) { + fn add_expr_compare( + &mut self, + expr_compare: &ast::ExprCompare, + expression: Expression<'db>, + is_positive: bool, + ) { let ast::ExprCompare { range: _, left, @@ -177,6 +199,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> { // we have no symbol to narrow down the type of. return; } + if !is_positive && comparators.len() > 1 { + // We can't negate a constraint made by a multi-comparator expression, since we can't + // know which comparison part is the one being negated. + // For example, the negation of `x is 1 is y is 2`, would be `(x is not 1) or (y is not 1) or (y is not 2)` + // and that requires cross-symbol constraints, which we don't support yet. + return; + } let scope = self.scope(); let inference = infer_expression_types(self.db, expression); @@ -192,12 +221,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> { { // SAFETY: we should always have a symbol for every Name node. let symbol = self.symbols().symbol_id_by_name(id).unwrap(); - let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope)); - match op { + let rhs_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope)); + + match if is_positive { *op } else { op.negate() } { ast::CmpOp::IsNot => { - if comp_ty.is_singleton() { + if rhs_ty.is_singleton() { let ty = IntersectionBuilder::new(self.db) - .add_negative(comp_ty) + .add_negative(rhs_ty) .build(); self.constraints.insert(symbol, ty); } else { @@ -205,12 +235,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } ast::CmpOp::Is => { - self.constraints.insert(symbol, comp_ty); + self.constraints.insert(symbol, rhs_ty); } ast::CmpOp::NotEq => { - if comp_ty.is_single_valued(self.db) { + if rhs_ty.is_single_valued(self.db) { let ty = IntersectionBuilder::new(self.db) - .add_negative(comp_ty) + .add_negative(rhs_ty) .build(); self.constraints.insert(symbol, ty); } @@ -223,7 +253,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } - fn add_expr_call(&mut self, expr_call: &ast::ExprCall, expression: Expression<'db>) { + fn add_expr_call( + &mut self, + expr_call: &ast::ExprCall, + expression: Expression<'db>, + is_positive: bool, + ) { let scope = self.scope(); let inference = infer_expression_types(self.db, expression); @@ -242,7 +277,11 @@ impl<'db> NarrowingConstraintsBuilder<'db> { // TODO: add support for PEP 604 union types on the right hand side: // isinstance(x, str | (int | float)) - if let Some(constraint) = generate_isinstance_constraint(self.db, &rhs_type) { + if let Some(mut constraint) = generate_isinstance_constraint(self.db, &rhs_type) + { + if !is_positive { + constraint = constraint.negate(self.db); + } self.constraints.insert(symbol, constraint); } } diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs index d2e584279c..a35868d6a0 100644 --- a/crates/ruff_python_ast/src/nodes.rs +++ b/crates/ruff_python_ast/src/nodes.rs @@ -3071,6 +3071,22 @@ impl CmpOp { CmpOp::NotIn => "not in", } } + + #[must_use] + pub const fn negate(&self) -> Self { + match self { + CmpOp::Eq => CmpOp::NotEq, + CmpOp::NotEq => CmpOp::Eq, + CmpOp::Lt => CmpOp::GtE, + CmpOp::LtE => CmpOp::Gt, + CmpOp::Gt => CmpOp::LtE, + CmpOp::GtE => CmpOp::Lt, + CmpOp::Is => CmpOp::IsNot, + CmpOp::IsNot => CmpOp::Is, + CmpOp::In => CmpOp::NotIn, + CmpOp::NotIn => CmpOp::In, + } + } } impl fmt::Display for CmpOp {