From 31bddef98fd99cf467f0164267361f04ac928c59 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 24 May 2023 10:10:15 -0400 Subject: [PATCH] Visit `TypeVar` and `NewType` name arguments (#4627) --- crates/ruff/src/checkers/ast/mod.rs | 85 +++++++++++++++++---------- crates/ruff/src/rules/pyflakes/mod.rs | 10 ++++ 2 files changed, 64 insertions(+), 31 deletions(-) diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index 96a92216f7..f48cded720 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -3685,32 +3685,42 @@ where match callable { Some(Callable::Bool) => { self.visit_expr(func); - if !args.is_empty() { - self.visit_boolean_test(&args[0]); + let mut args = args.iter(); + if let Some(arg) = args.next() { + self.visit_boolean_test(arg); } - for expr in args.iter().skip(1) { - self.visit_expr(expr); + for arg in args { + self.visit_expr(arg); } } Some(Callable::Cast) => { self.visit_expr(func); - if !args.is_empty() { - self.visit_type_definition(&args[0]); + let mut args = args.iter(); + if let Some(arg) = args.next() { + self.visit_type_definition(arg); } - for expr in args.iter().skip(1) { - self.visit_expr(expr); + for arg in args { + self.visit_expr(arg); } } Some(Callable::NewType) => { self.visit_expr(func); - for expr in args.iter().skip(1) { - self.visit_type_definition(expr); + let mut args = args.iter(); + if let Some(arg) = args.next() { + self.visit_non_type_definition(arg); + } + for arg in args { + self.visit_type_definition(arg); } } Some(Callable::TypeVar) => { self.visit_expr(func); - for expr in args.iter().skip(1) { - self.visit_type_definition(expr); + let mut args = args.iter(); + if let Some(arg) = args.next() { + self.visit_non_type_definition(arg); + } + for arg in args { + self.visit_type_definition(arg); } for keyword in keywords { let Keyword { @@ -3731,24 +3741,30 @@ where self.visit_expr(func); // Ex) NamedTuple("a", [("a", int)]) - if args.len() > 1 { - match &args[1] { - Expr::List(ast::ExprList { elts, .. }) - | Expr::Tuple(ast::ExprTuple { elts, .. }) => { - for elt in elts { - match elt { - Expr::List(ast::ExprList { elts, .. }) - | Expr::Tuple(ast::ExprTuple { elts, .. }) => { - if elts.len() == 2 { - self.visit_non_type_definition(&elts[0]); - self.visit_type_definition(&elts[1]); - } - } - _ => {} + let mut args = args.iter(); + if let Some(arg) = args.next() { + self.visit_non_type_definition(arg); + } + for arg in args { + if let Expr::List(ast::ExprList { elts, .. }) + | Expr::Tuple(ast::ExprTuple { elts, .. }) = arg + { + for elt in elts { + match elt { + Expr::List(ast::ExprList { elts, .. }) + | Expr::Tuple(ast::ExprTuple { elts, .. }) + if elts.len() == 2 => + { + self.visit_non_type_definition(&elts[0]); + self.visit_type_definition(&elts[1]); + } + _ => { + self.visit_non_type_definition(elt); } } } - _ => {} + } else { + self.visit_non_type_definition(arg); } } @@ -3762,12 +3778,16 @@ where self.visit_expr(func); // Ex) TypedDict("a", {"a": int}) - if args.len() > 1 { + let mut args = args.iter(); + if let Some(arg) = args.next() { + self.visit_non_type_definition(arg); + } + for arg in args { if let Expr::Dict(ast::ExprDict { keys, values, range: _, - }) = &args[1] + }) = arg { for key in keys.iter().flatten() { self.visit_non_type_definition(key); @@ -3775,6 +3795,8 @@ where for value in values { self.visit_type_definition(value); } + } else { + self.visit_non_type_definition(arg); } } @@ -3787,11 +3809,12 @@ where Some(Callable::MypyExtension) => { self.visit_expr(func); - if let Some(arg) = args.first() { + let mut args = args.iter(); + if let Some(arg) = args.next() { // Ex) DefaultNamedArg(bool | None, name="some_prop_name") self.visit_type_definition(arg); - for arg in args.iter().skip(1) { + for arg in args { self.visit_non_type_definition(arg); } for keyword in keywords { diff --git a/crates/ruff/src/rules/pyflakes/mod.rs b/crates/ruff/src/rules/pyflakes/mod.rs index 2fd06eb44a..cd5ac2c234 100644 --- a/crates/ruff/src/rules/pyflakes/mod.rs +++ b/crates/ruff/src/rules/pyflakes/mod.rs @@ -3689,6 +3689,16 @@ mod tests { "#, &[], ); + flakes( + r#" + from typing import NewType + + def f(): + name = "x" + NewType(name, int) + "#, + &[], + ); } #[test]