Check shifts of literals ints

This commit is contained in:
Brandt Bucher 2025-05-26 20:00:10 -07:00 committed by Alex Waygood
parent 4f60f0e925
commit 2882840abc
4 changed files with 136 additions and 54 deletions

View File

@ -141,3 +141,48 @@ class MyInt(int): ...
# No error for a subclass of int # No error for a subclass of int
reveal_type(MyInt(3) / 0) # revealed: int | float reveal_type(MyInt(3) / 0) # revealed: int | float
``` ```
## Bit-shifting
Literal artithmetic is supported for bit-shifting operations on `int`s:
```py
reveal_type(42 << 3) # revealed: Literal[336]
reveal_type(0 << 3) # revealed: Literal[0]
reveal_type(-42 << 3) # revealed: Literal[-336]
reveal_type(42 >> 3) # revealed: Literal[5]
reveal_type(0 >> 3) # revealed: Literal[0]
reveal_type(-42 >> 3) # revealed: Literal[-6]
```
If the result of a left shift overflows the `int` literal type, it becomes `int`. Right shifts do
not overflow:
```py
reveal_type(42 << 100) # revealed: int
reveal_type(0 << 100) # revealed: int
reveal_type(-42 << 100) # revealed: int
reveal_type(42 >> 100) # revealed: Literal[0]
reveal_type(0 >> 100) # revealed: Literal[0]
reveal_type(-42 >> 100) # revealed: Literal[-1]
```
It is an error to shift by a negative value. This is handled similarly to `division-by-zero`, above:
```py
# error: [negative-shift] "Cannot left shift object of type `Literal[42]` by a negative value"
reveal_type(42 << -3) # revealed: int
# error: [negative-shift] "Cannot left shift object of type `Literal[0]` by a negative value"
reveal_type(0 << -3) # revealed: int
# error: [negative-shift] "Cannot left shift object of type `Literal[-42]` by a negative value"
reveal_type(-42 << -3) # revealed: int
# error: [negative-shift] "Cannot right shift object of type `Literal[42]` by a negative value"
reveal_type(42 >> -3) # revealed: int
# error: [negative-shift] "Cannot right shift object of type `Literal[0]` by a negative value"
reveal_type(0 >> -3) # revealed: int
# error: [negative-shift] "Cannot right shift object of type `Literal[-42]` by a negative value"
reveal_type(-42 >> -3) # revealed: int
```

View File

@ -62,6 +62,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&INVALID_TYPE_GUARD_CALL); registry.register_lint(&INVALID_TYPE_GUARD_CALL);
registry.register_lint(&INVALID_TYPE_VARIABLE_CONSTRAINTS); registry.register_lint(&INVALID_TYPE_VARIABLE_CONSTRAINTS);
registry.register_lint(&MISSING_ARGUMENT); registry.register_lint(&MISSING_ARGUMENT);
registry.register_lint(&NEGATIVE_SHIFT);
registry.register_lint(&NO_MATCHING_OVERLOAD); registry.register_lint(&NO_MATCHING_OVERLOAD);
registry.register_lint(&NON_SUBSCRIPTABLE); registry.register_lint(&NON_SUBSCRIPTABLE);
registry.register_lint(&NOT_ITERABLE); registry.register_lint(&NOT_ITERABLE);
@ -1059,6 +1060,25 @@ declare_lint! {
} }
} }
declare_lint! {
/// ## What it does
/// Detects shifting an int by a negative value.
///
/// ## Why is this bad?
/// Shifting an int by a negative value raises a `ValueError` at runtime.
///
/// ## Examples
/// ```python
/// 42 >> -1
/// 42 << -1
/// ```
pub(crate) static NEGATIVE_SHIFT = {
summary: "detects shifting an int by a negative value",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! { declare_lint! {
/// ## What it does /// ## What it does
/// Checks for calls to an overloaded function that do not match any of the overloads. /// Checks for calls to an overloaded function that do not match any of the overloads.

View File

@ -92,7 +92,7 @@ use crate::types::diagnostic::{
CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_KW_ONLY, INCONSISTENT_MRO, CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_KW_ONLY, INCONSISTENT_MRO,
INVALID_ARGUMENT_TYPE, INVALID_ASSIGNMENT, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE, INVALID_ARGUMENT_TYPE, INVALID_ASSIGNMENT, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE,
INVALID_DECLARATION, INVALID_GENERIC_CLASS, INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM, INVALID_DECLARATION, INVALID_GENERIC_CLASS, INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM,
INVALID_TYPE_GUARD_CALL, INVALID_TYPE_VARIABLE_CONSTRAINTS, IncompatibleBases, INVALID_TYPE_GUARD_CALL, INVALID_TYPE_VARIABLE_CONSTRAINTS, IncompatibleBases, NEGATIVE_SHIFT,
POSSIBLY_UNBOUND_IMPLICIT_CALL, POSSIBLY_UNBOUND_IMPORT, TypeCheckDiagnostics, POSSIBLY_UNBOUND_IMPLICIT_CALL, POSSIBLY_UNBOUND_IMPORT, TypeCheckDiagnostics,
UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE,
UNSUPPORTED_OPERATOR, report_implicit_return_type, report_instance_layout_conflict, UNSUPPORTED_OPERATOR, report_implicit_return_type, report_instance_layout_conflict,
@ -1512,35 +1512,43 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
} }
/// Raise a diagnostic if the given type cannot be divided by zero. /// Raise a diagnostic if the given type cannot be divided by zero, or is shifted by a negative
/// value.
/// ///
/// Expects the resolved type of the left side of the binary expression. /// Expects the resolved type of the left side of the binary expression.
fn check_division_by_zero( fn check_bad_rhs(&mut self, node: AnyNodeRef<'_>, op: ast::Operator, left: Type<'db>) -> bool {
&mut self, let lhs_int = match left {
node: AnyNodeRef<'_>, Type::BooleanLiteral(_) | Type::IntLiteral(_) => true,
op: ast::Operator,
left: Type<'db>,
) -> bool {
match left {
Type::BooleanLiteral(_) | Type::IntLiteral(_) => {}
Type::NominalInstance(instance) Type::NominalInstance(instance)
if matches!( if matches!(
instance.class.known(self.db()), instance.class.known(self.db()),
Some(KnownClass::Float | KnownClass::Int | KnownClass::Bool) Some(KnownClass::Int | KnownClass::Bool)
) => {} ) =>
_ => return false, {
} true
}
let (op, by_zero) = match op { Type::NominalInstance(instance)
ast::Operator::Div => ("divide", "by zero"), if matches!(instance.class.known(self.db()), Some(KnownClass::Float)) =>
ast::Operator::FloorDiv => ("floor divide", "by zero"), {
ast::Operator::Mod => ("reduce", "modulo zero"), false
}
_ => return false, _ => return false,
}; };
if let Some(builder) = self.context.report_lint(&DIVISION_BY_ZERO, node) { let (op, by_what, lint) = match (op, lhs_int) {
(ast::Operator::Div, _) => ("divide", "by zero", &DIVISION_BY_ZERO),
(ast::Operator::FloorDiv, _) => ("floor divide", "by zero", &DIVISION_BY_ZERO),
(ast::Operator::Mod, _) => ("reduce", "modulo zero", &DIVISION_BY_ZERO),
(ast::Operator::LShift, true) => ("left shift", "by a negative value", &NEGATIVE_SHIFT),
(ast::Operator::RShift, true) => {
("right shift", "by a negative value", &NEGATIVE_SHIFT)
}
_ => return false,
};
if let Some(builder) = self.context.report_lint(lint, node) {
builder.into_diagnostic(format_args!( builder.into_diagnostic(format_args!(
"Cannot {op} object of type `{}` {by_zero}", "Cannot {op} object of type `{}` {by_what}",
left.display(self.db()) left.display(self.db())
)); ));
} }
@ -6460,30 +6468,31 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_binary_expression_type( fn infer_binary_expression_type(
&mut self, &mut self,
node: AnyNodeRef<'_>, node: AnyNodeRef<'_>,
mut emitted_division_by_zero_diagnostic: bool, mut emitted_bad_rhs_diagnostic: bool,
left_ty: Type<'db>, left_ty: Type<'db>,
right_ty: Type<'db>, right_ty: Type<'db>,
op: ast::Operator, op: ast::Operator,
) -> Option<Type<'db>> { ) -> Option<Type<'db>> {
// Check for division by zero; this doesn't change the inferred type for the expression, but // Check for division by zero or shift by a negative value; this doesn't change the inferred
// may emit a diagnostic // type for the expression, but may emit a diagnostic
if !emitted_division_by_zero_diagnostic if !emitted_bad_rhs_diagnostic {
&& matches!( emitted_bad_rhs_diagnostic = match (op, right_ty) {
(op, right_ty),
( (
ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod, ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod,
Type::IntLiteral(0) | Type::BooleanLiteral(false) Type::IntLiteral(0) | Type::BooleanLiteral(false),
) ) => self.check_bad_rhs(node, op, left_ty),
) (ast::Operator::LShift | ast::Operator::RShift, Type::IntLiteral(n)) if n < 0 => {
{ self.check_bad_rhs(node, op, left_ty)
emitted_division_by_zero_diagnostic = self.check_division_by_zero(node, op, left_ty); }
_ => false,
};
} }
match (left_ty, right_ty, op) { match (left_ty, right_ty, op) {
(Type::Union(lhs_union), rhs, _) => lhs_union.try_map(self.db(), |lhs_element| { (Type::Union(lhs_union), rhs, _) => lhs_union.try_map(self.db(), |lhs_element| {
self.infer_binary_expression_type( self.infer_binary_expression_type(
node, node,
emitted_division_by_zero_diagnostic, emitted_bad_rhs_diagnostic,
*lhs_element, *lhs_element,
rhs, rhs,
op, op,
@ -6492,15 +6501,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
(lhs, Type::Union(rhs_union), _) => rhs_union.try_map(self.db(), |rhs_element| { (lhs, Type::Union(rhs_union), _) => rhs_union.try_map(self.db(), |rhs_element| {
self.infer_binary_expression_type( self.infer_binary_expression_type(
node, node,
emitted_division_by_zero_diagnostic, emitted_bad_rhs_diagnostic,
lhs, lhs,
*rhs_element, *rhs_element,
op, op,
) )
}), }),
// Non-todo Anys take precedence over Todos (as if we fix this `Todo` in the future,
// the result would then become Any or Unknown, respectively).
(any @ Type::Dynamic(DynamicType::Any), _, _) (any @ Type::Dynamic(DynamicType::Any), _, _)
| (_, any @ Type::Dynamic(DynamicType::Any), _) => Some(any), | (_, any @ Type::Dynamic(DynamicType::Any), _) => Some(any),
(unknown @ Type::Dynamic(DynamicType::Unknown), _, _) (unknown @ Type::Dynamic(DynamicType::Unknown), _, _)
@ -6595,6 +6602,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Some(Type::IntLiteral(n ^ m)) Some(Type::IntLiteral(n ^ m))
} }
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::LShift) => Some(
u32::try_from(m)
.ok()
.and_then(|m| n.checked_shl(m))
.map(Type::IntLiteral)
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db())),
),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::RShift) => Some(
u32::try_from(m)
.ok()
.map(|m| n >> m.clamp(0, 63))
.map(Type::IntLiteral)
.unwrap_or_else(|| KnownClass::Int.to_instance(self.db())),
),
(Type::BytesLiteral(lhs), Type::BytesLiteral(rhs), ast::Operator::Add) => { (Type::BytesLiteral(lhs), Type::BytesLiteral(rhs), ast::Operator::Add) => {
let bytes = [lhs.value(self.db()), rhs.value(self.db())].concat(); let bytes = [lhs.value(self.db()), rhs.value(self.db())].concat();
Some(Type::bytes_literal(self.db(), &bytes)) Some(Type::bytes_literal(self.db(), &bytes))
@ -6661,7 +6684,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
(Type::BooleanLiteral(b1), Type::BooleanLiteral(_) | Type::IntLiteral(_), op) => self (Type::BooleanLiteral(b1), Type::BooleanLiteral(_) | Type::IntLiteral(_), op) => self
.infer_binary_expression_type( .infer_binary_expression_type(
node, node,
emitted_division_by_zero_diagnostic, emitted_bad_rhs_diagnostic,
Type::IntLiteral(i64::from(b1)), Type::IntLiteral(i64::from(b1)),
right_ty, right_ty,
op, op,
@ -6669,7 +6692,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
(Type::IntLiteral(_), Type::BooleanLiteral(b2), op) => self (Type::IntLiteral(_), Type::BooleanLiteral(b2), op) => self
.infer_binary_expression_type( .infer_binary_expression_type(
node, node,
emitted_division_by_zero_diagnostic, emitted_bad_rhs_diagnostic,
left_ty, left_ty,
Type::IntLiteral(i64::from(b2)), Type::IntLiteral(i64::from(b2)),
op, op,
@ -6694,22 +6717,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
| Type::DataclassDecorator(_) | Type::DataclassDecorator(_)
| Type::DataclassTransformer(_) | Type::DataclassTransformer(_)
| Type::ModuleLiteral(_) | Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::GenericAlias(_)
| Type::SubclassOf(_)
| Type::NominalInstance(_)
| Type::ProtocolInstance(_)
| Type::SpecialForm(_)
| Type::KnownInstance(_)
| Type::PropertyInstance(_)
| Type::Intersection(_)
| Type::AlwaysTruthy
| Type::AlwaysFalsy
| Type::IntLiteral(_)
| Type::StringLiteral(_)
| Type::LiteralString
| Type::BytesLiteral(_)
| Type::Tuple(_)
| Type::BoundSuper(_) | Type::BoundSuper(_)
| Type::TypeVar(_) | Type::TypeVar(_)
| Type::TypeIs(_), | Type::TypeIs(_),

10
ty.schema.json generated
View File

@ -671,6 +671,16 @@
} }
] ]
}, },
"negative-shift": {
"title": "detects shifting an int by a negative value",
"description": "## What it does\nDetects shifting an int by a negative value.\n\n## Why is this bad?\nShifting an int by a negative value raises a `ValueError` at runtime.\n\n## Examples\n```python\n42 >> -1\n42 << -1\n```",
"default": "error",
"oneOf": [
{
"$ref": "#/definitions/Level"
}
]
},
"no-matching-overload": { "no-matching-overload": {
"title": "detects calls that do not match any overload", "title": "detects calls that do not match any overload",
"description": "## What it does\nChecks for calls to an overloaded function that do not match any of the overloads.\n\n## Why is this bad?\nFailing to provide the correct arguments to one of the overloads will raise a `TypeError`\nat runtime.\n\n## Examples\n```python\n@overload\ndef func(x: int): ...\n@overload\ndef func(x: bool): ...\nfunc(\"string\") # error: [no-matching-overload]\n```", "description": "## What it does\nChecks for calls to an overloaded function that do not match any of the overloads.\n\n## Why is this bad?\nFailing to provide the correct arguments to one of the overloads will raise a `TypeError`\nat runtime.\n\n## Examples\n```python\n@overload\ndef func(x: int): ...\n@overload\ndef func(x: bool): ...\nfunc(\"string\") # error: [no-matching-overload]\n```",