diff --git a/src/check_ast.rs b/src/check_ast.rs index f3003fb5af..0cd24dd84b 100644 --- a/src/check_ast.rs +++ b/src/check_ast.rs @@ -159,7 +159,13 @@ impl<'a> Checker<'a> { } /// Return `true` if the `Expr` is a reference to `typing.${target}`. - pub fn match_typing_module(&self, call_path: &[&str], target: &str) -> bool { + pub fn match_typing_expr(&self, expr: &Expr, target: &str) -> bool { + let call_path = dealias_call_path(collect_call_paths(expr), &self.import_aliases); + self.match_typing_call_path(&call_path, target) + } + + /// Return `true` if the call path is a reference to `typing.${target}`. + pub fn match_typing_call_path(&self, call_path: &[&str], target: &str) -> bool { match_call_path(call_path, "typing", target, &self.from_imports) || (typing::in_extensions(target) && match_call_path(call_path, "typing_extensions", target, &self.from_imports)) @@ -1058,7 +1064,7 @@ where pyupgrade::plugins::use_pep604_annotation(self, expr, value, slice); } - if self.match_typing_module(&collect_call_paths(value), "Literal") { + if self.match_typing_expr(value, "Literal") { self.in_literal = true; } @@ -1646,12 +1652,12 @@ where keywords, } => { let call_path = dealias_call_path(collect_call_paths(func), &self.import_aliases); - if self.match_typing_module(&call_path, "ForwardRef") { + if self.match_typing_call_path(&call_path, "ForwardRef") { self.visit_expr(func); for expr in args { self.visit_annotation(expr); } - } else if self.match_typing_module(&call_path, "cast") { + } else if self.match_typing_call_path(&call_path, "cast") { self.visit_expr(func); if !args.is_empty() { self.visit_annotation(&args[0]); @@ -1659,12 +1665,12 @@ where for expr in args.iter().skip(1) { self.visit_expr(expr); } - } else if self.match_typing_module(&call_path, "NewType") { + } else if self.match_typing_call_path(&call_path, "NewType") { self.visit_expr(func); for expr in args.iter().skip(1) { self.visit_annotation(expr); } - } else if self.match_typing_module(&call_path, "TypeVar") { + } else if self.match_typing_call_path(&call_path, "TypeVar") { self.visit_expr(func); for expr in args.iter().skip(1) { self.visit_annotation(expr); @@ -1681,7 +1687,7 @@ where } } } - } else if self.match_typing_module(&call_path, "NamedTuple") { + } else if self.match_typing_call_path(&call_path, "NamedTuple") { self.visit_expr(func); // Ex) NamedTuple("a", [("a", int)]) @@ -1713,7 +1719,7 @@ where let KeywordData { value, .. } = &keyword.node; self.visit_annotation(value); } - } else if self.match_typing_module(&call_path, "TypedDict") { + } else if self.match_typing_call_path(&call_path, "TypedDict") { self.visit_expr(func); // Ex) TypedDict("a", {"a": int}) diff --git a/src/flake8_annotations/plugins.rs b/src/flake8_annotations/plugins.rs index 0955bd59d5..95300d8bb9 100644 --- a/src/flake8_annotations/plugins.rs +++ b/src/flake8_annotations/plugins.rs @@ -1,6 +1,5 @@ use rustpython_ast::{Arguments, Constant, Expr, ExprKind, Stmt, StmtKind}; -use crate::ast::helpers::collect_call_paths; use crate::ast::types::Range; use crate::ast::visitor; use crate::ast::visitor::Visitor; @@ -54,7 +53,7 @@ fn check_dynamically_typed(checker: &mut Checker, annotation: &Expr, func: F) where F: FnOnce() -> String, { - if checker.match_typing_module(&collect_call_paths(annotation), "Any") { + if checker.match_typing_expr(annotation, "Any") { checker.add_check(Check::new( CheckKind::DynamicallyTypedExpression(func()), Range::from_located(annotation), diff --git a/src/pyupgrade/plugins/use_pep604_annotation.rs b/src/pyupgrade/plugins/use_pep604_annotation.rs index 762893c6e8..0ecba45541 100644 --- a/src/pyupgrade/plugins/use_pep604_annotation.rs +++ b/src/pyupgrade/plugins/use_pep604_annotation.rs @@ -1,6 +1,6 @@ use rustpython_ast::{Constant, Expr, ExprKind, Operator}; -use crate::ast::helpers::collect_call_paths; +use crate::ast::helpers::{collect_call_paths, dealias_call_path}; use crate::ast::types::Range; use crate::autofix::Fix; use crate::check_ast::Checker; @@ -44,8 +44,8 @@ fn union(elts: &[Expr]) -> Expr { /// U007 pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, slice: &Expr) { - let call_path = collect_call_paths(value); - if checker.match_typing_module(&call_path, "Optional") { + let call_path = dealias_call_path(collect_call_paths(value), &checker.import_aliases); + if checker.match_typing_call_path(&call_path, "Optional") { let mut check = Check::new(CheckKind::UsePEP604Annotation, Range::from_located(expr)); if checker.patch() { let mut generator = SourceGenerator::new(); @@ -60,7 +60,7 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s } } checker.add_check(check); - } else if checker.match_typing_module(&call_path, "Union") { + } else if checker.match_typing_call_path(&call_path, "Union") { let mut check = Check::new(CheckKind::UsePEP604Annotation, Range::from_located(expr)); if checker.patch() { match &slice.node {