diff --git a/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs b/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs index cbb8b10b4e..137b346cda 100644 --- a/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs +++ b/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs @@ -3,12 +3,13 @@ use std::str::FromStr; use ruff_text_size::TextRange; use rustc_hash::FxHashMap; use rustpython_format::cformat::{CFormatPart, CFormatSpec, CFormatStrOrBytes, CFormatString}; -use rustpython_parser::ast::{self, Constant, Expr, Operator, Ranged}; +use rustpython_parser::ast::{self, Constant, Expr, Ranged}; use rustpython_parser::{lexer, Mode, Tok}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::str::{leading_quote, trailing_quote}; +use ruff_python_semantic::analyze::type_inference::PythonType; use crate::checkers::ast::Checker; @@ -38,87 +39,6 @@ impl Violation for BadStringFormatType { } } -#[derive(Debug, Copy, Clone)] -enum DataType { - String, - Integer, - Float, - Object, - Unknown, -} - -impl From<&Expr> for DataType { - fn from(expr: &Expr) -> Self { - match expr { - Expr::NamedExpr(ast::ExprNamedExpr { value, .. }) => (&**value).into(), - Expr::UnaryOp(ast::ExprUnaryOp { operand, .. }) => (&**operand).into(), - Expr::Dict(_) => DataType::Object, - Expr::Set(_) => DataType::Object, - Expr::ListComp(_) => DataType::Object, - Expr::SetComp(_) => DataType::Object, - Expr::DictComp(_) => DataType::Object, - Expr::GeneratorExp(_) => DataType::Object, - Expr::JoinedStr(_) => DataType::String, - Expr::BinOp(ast::ExprBinOp { left, op, .. }) => { - // Ex) "a" % "b" - if matches!( - left.as_ref(), - Expr::Constant(ast::ExprConstant { - value: Constant::Str(..), - .. - }) - ) && matches!(op, Operator::Mod) - { - return DataType::String; - } - DataType::Unknown - } - Expr::Constant(ast::ExprConstant { value, .. }) => match value { - Constant::Str(_) => DataType::String, - Constant::Int(_) => DataType::Integer, - Constant::Float(_) => DataType::Float, - _ => DataType::Unknown, - }, - Expr::List(_) => DataType::Object, - Expr::Tuple(_) => DataType::Object, - _ => DataType::Unknown, - } - } -} - -impl DataType { - fn is_compatible_with(self, format: FormatType) -> bool { - match self { - DataType::String => matches!( - format, - FormatType::Unknown | FormatType::String | FormatType::Repr - ), - DataType::Object => matches!( - format, - FormatType::Unknown | FormatType::String | FormatType::Repr - ), - DataType::Integer => matches!( - format, - FormatType::Unknown - | FormatType::String - | FormatType::Repr - | FormatType::Integer - | FormatType::Float - | FormatType::Number - ), - DataType::Float => matches!( - format, - FormatType::Unknown - | FormatType::String - | FormatType::Repr - | FormatType::Float - | FormatType::Number - ), - DataType::Unknown => true, - } - } -} - #[derive(Debug, Copy, Clone)] enum FormatType { Repr, @@ -129,6 +49,45 @@ enum FormatType { Unknown, } +impl FormatType { + fn is_compatible_with(self, data_type: PythonType) -> bool { + match data_type { + PythonType::String + | PythonType::Bytes + | PythonType::List + | PythonType::Dict + | PythonType::Set + | PythonType::Tuple + | PythonType::Generator + | PythonType::Complex + | PythonType::Bool + | PythonType::Ellipsis + | PythonType::None => matches!( + self, + FormatType::Unknown | FormatType::String | FormatType::Repr + ), + PythonType::Integer => matches!( + self, + FormatType::Unknown + | FormatType::String + | FormatType::Repr + | FormatType::Integer + | FormatType::Float + | FormatType::Number + ), + PythonType::Float => matches!( + self, + FormatType::Unknown + | FormatType::String + | FormatType::Repr + | FormatType::Float + | FormatType::Number + ), + PythonType::Unknown => true, + } + } +} + impl From for FormatType { fn from(format: char) -> Self { match format { @@ -159,9 +118,9 @@ fn collect_specs(formats: &[CFormatStrOrBytes]) -> Vec<&CFormatSpec> { /// Return `true` if the format string is equivalent to the constant type fn equivalent(format: &CFormatSpec, value: &Expr) -> bool { - let constant: DataType = value.into(); let format: FormatType = format.format_char.into(); - constant.is_compatible_with(format) + let constant: PythonType = value.into(); + format.is_compatible_with(constant) } /// Return `true` if the [`Constnat`] aligns with the format type. diff --git a/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs b/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs index eeaeb44d2e..486c8a474e 100644 --- a/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs +++ b/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs @@ -1,8 +1,9 @@ -use rustpython_parser::ast::{self, Constant, Expr, Ranged, Stmt}; +use rustpython_parser::ast::{Ranged, Stmt}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::{helpers::ReturnStatementVisitor, statement_visitor::StatementVisitor}; +use ruff_python_semantic::analyze::type_inference::PythonType; use crate::checkers::ast::Checker; @@ -39,37 +40,21 @@ pub(crate) fn invalid_str_return(checker: &mut Checker, name: &str, body: &[Stmt }; for stmt in returns { - // Disallow implicit `None`. - let Some(value) = stmt.value.as_deref() else { - checker.diagnostics.push(Diagnostic::new(InvalidStrReturnType, stmt.range())); - continue; - }; - - // Disallow other constants. - if matches!( - value, - Expr::List(_) - | Expr::Dict(_) - | Expr::Set(_) - | Expr::ListComp(_) - | Expr::DictComp(_) - | Expr::SetComp(_) - | Expr::GeneratorExp(_) - | Expr::Constant(ast::ExprConstant { - value: Constant::None - | Constant::Bool(_) - | Constant::Bytes(_) - | Constant::Int(_) - | Constant::Tuple(_) - | Constant::Float(_) - | Constant::Complex { .. } - | Constant::Ellipsis, - .. - }) - ) { + if let Some(value) = stmt.value.as_deref() { + // Disallow other, non- + if !matches!( + PythonType::from(value), + PythonType::String | PythonType::Unknown + ) { + checker + .diagnostics + .push(Diagnostic::new(InvalidStrReturnType, value.range())); + } + } else { + // Disallow implicit `None`. checker .diagnostics - .push(Diagnostic::new(InvalidStrReturnType, value.range())); + .push(Diagnostic::new(InvalidStrReturnType, stmt.range())); } } } diff --git a/crates/ruff_python_semantic/src/analyze/mod.rs b/crates/ruff_python_semantic/src/analyze/mod.rs index a4cd2fdf50..f8cb066480 100644 --- a/crates/ruff_python_semantic/src/analyze/mod.rs +++ b/crates/ruff_python_semantic/src/analyze/mod.rs @@ -1,5 +1,6 @@ pub mod branch_detection; pub mod function_type; pub mod logging; +pub mod type_inference; pub mod typing; pub mod visibility; diff --git a/crates/ruff_python_semantic/src/analyze/type_inference.rs b/crates/ruff_python_semantic/src/analyze/type_inference.rs new file mode 100644 index 0000000000..e040fe2805 --- /dev/null +++ b/crates/ruff_python_semantic/src/analyze/type_inference.rs @@ -0,0 +1,96 @@ +//! Analysis rules to perform basic type inference on individual expressions. + +use rustpython_parser::ast; +use rustpython_parser::ast::{Constant, Expr}; + +/// An extremely simple type inference system for individual expressions. +/// +/// This system can only represent and infer the types of simple data types +/// such as strings, integers, floats, and containers. It cannot infer the +/// types of variables or expressions that are not statically known from +/// individual AST nodes alone. +#[derive(Debug, Copy, Clone)] +pub enum PythonType { + /// A string literal, such as `"hello"`. + String, + /// A bytes literal, such as `b"hello"`. + Bytes, + /// An integer literal, such as `1` or `0x1`. + Integer, + /// A floating-point literal, such as `1.0` or `1e10`. + Float, + /// A complex literal, such as `1j` or `1+1j`. + Complex, + /// A boolean literal, such as `True` or `False`. + Bool, + /// A `None` literal, such as `None`. + None, + /// An ellipsis literal, such as `...`. + Ellipsis, + /// A dictionary literal, such as `{}` or `{"a": 1}`. + Dict, + /// A list literal, such as `[]` or `[i for i in range(3)]`. + List, + /// A set literal, such as `set()` or `{i for i in range(3)}`. + Set, + /// A tuple literal, such as `()` or `(1, 2, 3)`. + Tuple, + /// A generator expression, such as `(x for x in range(10))`. + Generator, + /// An unknown type, such as a variable or function call. + Unknown, +} + +impl From<&Expr> for PythonType { + fn from(expr: &Expr) -> Self { + match expr { + Expr::NamedExpr(ast::ExprNamedExpr { value, .. }) => (&**value).into(), + Expr::UnaryOp(ast::ExprUnaryOp { operand, .. }) => (&**operand).into(), + Expr::Dict(_) => PythonType::Dict, + Expr::DictComp(_) => PythonType::Dict, + Expr::Set(_) => PythonType::Set, + Expr::SetComp(_) => PythonType::Set, + Expr::List(_) => PythonType::List, + Expr::ListComp(_) => PythonType::List, + Expr::Tuple(_) => PythonType::Tuple, + Expr::GeneratorExp(_) => PythonType::Generator, + Expr::JoinedStr(_) => PythonType::String, + Expr::BinOp(ast::ExprBinOp { left, op, .. }) => { + // Ex) "a" % "b" + if op.is_mod() { + if matches!( + left.as_ref(), + Expr::Constant(ast::ExprConstant { + value: Constant::Str(..), + .. + }) + ) { + return PythonType::String; + } + if matches!( + left.as_ref(), + Expr::Constant(ast::ExprConstant { + value: Constant::Bytes(..), + .. + }) + ) { + return PythonType::Bytes; + } + } + PythonType::Unknown + } + Expr::Constant(ast::ExprConstant { value, .. }) => match value { + Constant::Str(_) => PythonType::String, + Constant::Int(_) => PythonType::Integer, + Constant::Float(_) => PythonType::Float, + Constant::Bool(_) => PythonType::Bool, + Constant::Complex { .. } => PythonType::Complex, + Constant::None => PythonType::None, + Constant::Ellipsis => PythonType::Ellipsis, + Constant::Bytes(_) => PythonType::Bytes, + Constant::Tuple(_) => PythonType::Tuple, + }, + _ => PythonType::Unknown, + } + } +} diff --git a/crates/ruff_python_semantic/src/analyze/typing.rs b/crates/ruff_python_semantic/src/analyze/typing.rs index 6e3ce7ad19..694e97e650 100644 --- a/crates/ruff_python_semantic/src/analyze/typing.rs +++ b/crates/ruff_python_semantic/src/analyze/typing.rs @@ -1,3 +1,5 @@ +//! Analysis rules for the `typing` module. + use rustpython_parser::ast::{self, Constant, Expr, Operator}; use num_traits::identities::Zero;