From 3aa91a853e5dbbb7907164233154db268eefdb1b Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 11 Jul 2025 13:43:35 +0100 Subject: [PATCH] add tests for bools and make helper method private --- .../resources/mdtest/binary/booleans.md | 19 +++++ .../resources/mdtest/binary/integers.md | 2 +- crates/ty_python_semantic/src/types/infer.rs | 84 +++++++++---------- 3 files changed, 60 insertions(+), 45 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/binary/booleans.md b/crates/ty_python_semantic/resources/mdtest/binary/booleans.md index 5017a93178..0a70ccfa6c 100644 --- a/crates/ty_python_semantic/resources/mdtest/binary/booleans.md +++ b/crates/ty_python_semantic/resources/mdtest/binary/booleans.md @@ -57,6 +57,25 @@ reveal_type(a ^ a) # revealed: Literal[False] reveal_type(a ^ b) # revealed: Literal[True] reveal_type(b ^ a) # revealed: Literal[True] reveal_type(b ^ b) # revealed: Literal[False] + +# left-shift +reveal_type(a << a) # revealed: Literal[2] +reveal_type(a << b) # revealed: Literal[1] +reveal_type(b << a) # revealed: Literal[0] +reveal_type(b << b) # revealed: Literal[0] +reveal_type(True << 100) # revealed: int + +# error: [literal-math-error] "Cannot left shift object of type `Literal[True]` by a negative value" +reveal_type(True << -1) # revealed: int + +# right-shift +reveal_type(a >> a) # revealed: Literal[0] +reveal_type(a >> b) # revealed: Literal[1] +reveal_type(b >> a) # revealed: Literal[0] +reveal_type(b >> b) # revealed: Literal[0] + +# error: [literal-math-error] "Cannot right shift object of type `Literal[False]` by a negative value" +reveal_type(False >> -1) # revealed: int ``` ## Arithmetic with a variable diff --git a/crates/ty_python_semantic/resources/mdtest/binary/integers.md b/crates/ty_python_semantic/resources/mdtest/binary/integers.md index c81c8fdb08..20e6b4aca9 100644 --- a/crates/ty_python_semantic/resources/mdtest/binary/integers.md +++ b/crates/ty_python_semantic/resources/mdtest/binary/integers.md @@ -151,7 +151,7 @@ reveal_type(MyInt(3) / 0) # revealed: int | float ## Bit-shifting -Literal artithmetic is supported for bit-shifting operations on `int`s: +Literal arithmetic is supported for bit-shifting operations on `int`s: ```py reveal_type(42 << 3) # revealed: Literal[336] diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index e3b384d8f9..1248fe7316 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -1512,48 +1512,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - /// Emit a diagnostic if the given type cannot be divided by zero, or is shifted by a negative - /// value. - /// - /// Expects the resolved type of the left side of the binary expression. - fn check_bad_rhs(&mut self, node: AnyNodeRef<'_>, op: ast::Operator, left: Type<'db>) -> bool { - let lhs_int = match left { - Type::BooleanLiteral(_) | Type::IntLiteral(_) => true, - Type::NominalInstance(instance) - if matches!( - instance.class.known(self.db()), - Some(KnownClass::Int | KnownClass::Bool) - ) => - { - true - } - Type::NominalInstance(instance) - if matches!(instance.class.known(self.db()), Some(KnownClass::Float)) => - { - false - } - _ => return false, - }; - - let (op, by_what) = match (op, lhs_int) { - (ast::Operator::Div, _) => ("divide", "by zero"), - (ast::Operator::FloorDiv, _) => ("floor divide", "by zero"), - (ast::Operator::Mod, _) => ("reduce", "modulo zero"), - (ast::Operator::LShift, true) => ("left shift", "by a negative value"), - (ast::Operator::RShift, true) => ("right shift", "by a negative value"), - _ => return false, - }; - - if let Some(builder) = self.context.report_lint(&LITERAL_MATH_ERROR, node) { - builder.into_diagnostic(format_args!( - "Cannot {op} object of type `{}` {by_what}", - left.display(self.db()) - )); - } - - true - } - fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) { debug_assert!( binding @@ -6474,6 +6432,44 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { right_ty: Type<'db>, op: ast::Operator, ) -> Option> { + let check_bad_rhs = || { + let lhs_int = match left_ty { + Type::BooleanLiteral(_) | Type::IntLiteral(_) => true, + Type::NominalInstance(instance) + if matches!( + instance.class.known(self.db()), + Some(KnownClass::Int | KnownClass::Bool) + ) => + { + true + } + Type::NominalInstance(instance) + if matches!(instance.class.known(self.db()), Some(KnownClass::Float)) => + { + false + } + _ => return false, + }; + + let (op, by_what) = match (op, lhs_int) { + (ast::Operator::Div, _) => ("divide", "by zero"), + (ast::Operator::FloorDiv, _) => ("floor divide", "by zero"), + (ast::Operator::Mod, _) => ("reduce", "modulo zero"), + (ast::Operator::LShift, true) => ("left shift", "by a negative value"), + (ast::Operator::RShift, true) => ("right shift", "by a negative value"), + _ => return false, + }; + + if let Some(builder) = self.context.report_lint(&LITERAL_MATH_ERROR, node) { + builder.into_diagnostic(format_args!( + "Cannot {op} object of type `{}` {by_what}", + left_ty.display(self.db()) + )); + } + + true + }; + // Check for division by zero or shift by a negative value; this doesn't change the inferred // type for the expression, but may emit a diagnostic if !emitted_bad_rhs_diagnostic { @@ -6481,9 +6477,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ( ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod, Type::IntLiteral(0) | Type::BooleanLiteral(false), - ) => self.check_bad_rhs(node, op, left_ty), + ) => check_bad_rhs(), (ast::Operator::LShift | ast::Operator::RShift, Type::IntLiteral(n)) if n < 0 => { - self.check_bad_rhs(node, op, left_ty) + check_bad_rhs() } _ => false, };