add tests for bools and make helper method private

This commit is contained in:
Alex Waygood 2025-07-11 13:43:35 +01:00
parent 59570beb57
commit 3aa91a853e
3 changed files with 60 additions and 45 deletions

View File

@ -57,6 +57,25 @@ reveal_type(a ^ a) # revealed: Literal[False]
reveal_type(a ^ b) # revealed: Literal[True]
reveal_type(b ^ a) # revealed: Literal[True]
reveal_type(b ^ b) # revealed: Literal[False]
# left-shift
reveal_type(a << a) # revealed: Literal[2]
reveal_type(a << b) # revealed: Literal[1]
reveal_type(b << a) # revealed: Literal[0]
reveal_type(b << b) # revealed: Literal[0]
reveal_type(True << 100) # revealed: int
# error: [literal-math-error] "Cannot left shift object of type `Literal[True]` by a negative value"
reveal_type(True << -1) # revealed: int
# right-shift
reveal_type(a >> a) # revealed: Literal[0]
reveal_type(a >> b) # revealed: Literal[1]
reveal_type(b >> a) # revealed: Literal[0]
reveal_type(b >> b) # revealed: Literal[0]
# error: [literal-math-error] "Cannot right shift object of type `Literal[False]` by a negative value"
reveal_type(False >> -1) # revealed: int
```
## Arithmetic with a variable

View File

@ -151,7 +151,7 @@ reveal_type(MyInt(3) / 0) # revealed: int | float
## Bit-shifting
Literal artithmetic is supported for bit-shifting operations on `int`s:
Literal arithmetic is supported for bit-shifting operations on `int`s:
```py
reveal_type(42 << 3) # revealed: Literal[336]

View File

@ -1512,48 +1512,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
/// Emit a diagnostic if the given type cannot be divided by zero, or is shifted by a negative
/// value.
///
/// Expects the resolved type of the left side of the binary expression.
fn check_bad_rhs(&mut self, node: AnyNodeRef<'_>, op: ast::Operator, left: Type<'db>) -> bool {
let lhs_int = match left {
Type::BooleanLiteral(_) | Type::IntLiteral(_) => true,
Type::NominalInstance(instance)
if matches!(
instance.class.known(self.db()),
Some(KnownClass::Int | KnownClass::Bool)
) =>
{
true
}
Type::NominalInstance(instance)
if matches!(instance.class.known(self.db()), Some(KnownClass::Float)) =>
{
false
}
_ => return false,
};
let (op, by_what) = match (op, lhs_int) {
(ast::Operator::Div, _) => ("divide", "by zero"),
(ast::Operator::FloorDiv, _) => ("floor divide", "by zero"),
(ast::Operator::Mod, _) => ("reduce", "modulo zero"),
(ast::Operator::LShift, true) => ("left shift", "by a negative value"),
(ast::Operator::RShift, true) => ("right shift", "by a negative value"),
_ => return false,
};
if let Some(builder) = self.context.report_lint(&LITERAL_MATH_ERROR, node) {
builder.into_diagnostic(format_args!(
"Cannot {op} object of type `{}` {by_what}",
left.display(self.db())
));
}
true
}
fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) {
debug_assert!(
binding
@ -6474,6 +6432,44 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
right_ty: Type<'db>,
op: ast::Operator,
) -> Option<Type<'db>> {
let check_bad_rhs = || {
let lhs_int = match left_ty {
Type::BooleanLiteral(_) | Type::IntLiteral(_) => true,
Type::NominalInstance(instance)
if matches!(
instance.class.known(self.db()),
Some(KnownClass::Int | KnownClass::Bool)
) =>
{
true
}
Type::NominalInstance(instance)
if matches!(instance.class.known(self.db()), Some(KnownClass::Float)) =>
{
false
}
_ => return false,
};
let (op, by_what) = match (op, lhs_int) {
(ast::Operator::Div, _) => ("divide", "by zero"),
(ast::Operator::FloorDiv, _) => ("floor divide", "by zero"),
(ast::Operator::Mod, _) => ("reduce", "modulo zero"),
(ast::Operator::LShift, true) => ("left shift", "by a negative value"),
(ast::Operator::RShift, true) => ("right shift", "by a negative value"),
_ => return false,
};
if let Some(builder) = self.context.report_lint(&LITERAL_MATH_ERROR, node) {
builder.into_diagnostic(format_args!(
"Cannot {op} object of type `{}` {by_what}",
left_ty.display(self.db())
));
}
true
};
// Check for division by zero or shift by a negative value; this doesn't change the inferred
// type for the expression, but may emit a diagnostic
if !emitted_bad_rhs_diagnostic {
@ -6481,9 +6477,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
(
ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod,
Type::IntLiteral(0) | Type::BooleanLiteral(false),
) => self.check_bad_rhs(node, op, left_ty),
) => check_bad_rhs(),
(ast::Operator::LShift | ast::Operator::RShift, Type::IntLiteral(n)) if n < 0 => {
self.check_bad_rhs(node, op, left_ty)
check_bad_rhs()
}
_ => false,
};