diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C417.py b/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C417.py index dd7b93f890..7c2cdfbd78 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C417.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C417.py @@ -75,3 +75,7 @@ list(map(lambda x, y: x, [(1, 2), (3, 4)])) _ = t"{set(map(lambda x: x % 2 == 0, nums))}" _ = t"{dict(map(lambda v: (v, v**2), nums))}" + +# See https://github.com/astral-sh/ruff/issues/20198 +# No error: lambda contains `yield`, so map() should not be rewritten +map(lambda x: (yield x), [1, 2, 3]) diff --git a/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_map.rs b/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_map.rs index e31c4b15a0..362c03b68c 100644 --- a/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_map.rs +++ b/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_map.rs @@ -122,6 +122,13 @@ pub(crate) fn unnecessary_map(checker: &Checker, call: &ast::ExprCall) { } }; + // If the lambda body contains a `yield` or `yield from`, rewriting `map(lambda ...)` to a + // generator expression or any comprehension is invalid Python syntax + // (e.g., `yield` is not allowed inside generator or comprehension expressions). In such cases, skip. + if lambda_contains_yield(&lambda.body) { + return; + } + for iterable in iterables { // For example, (x+1 for x in (c:=a)) is invalid syntax // so we can't suggest it. @@ -183,6 +190,13 @@ fn map_lambda_and_iterables<'a>( Some((lambda, iterables)) } +/// Returns true if the expression tree contains a `yield` or `yield from` expression. +fn lambda_contains_yield(expr: &Expr) -> bool { + any_over_expr(expr, &|expr| { + matches!(expr, Expr::Yield(_) | Expr::YieldFrom(_)) + }) +} + /// A lambda as the first argument to `map()` has the "expected" arity when: /// /// * It has exactly one parameter diff --git a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C417_C417.py.snap b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C417_C417.py.snap index 253c858611..31d0647726 100644 --- a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C417_C417.py.snap +++ b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C417_C417.py.snap @@ -342,6 +342,7 @@ help: Replace `map()` with a set comprehension 75 + _ = t"{ {x % 2 == 0 for x in nums} }" 76 | _ = t"{dict(map(lambda v: (v, v**2), nums))}" 77 | +78 | note: This is an unsafe fix and may change runtime behavior C417 [*] Unnecessary `map()` usage (rewrite using a dict comprehension) @@ -359,4 +360,6 @@ help: Replace `map()` with a dict comprehension - _ = t"{dict(map(lambda v: (v, v**2), nums))}" 76 + _ = t"{ {v: v**2 for v in nums} }" 77 | +78 | +79 | # See https://github.com/astral-sh/ruff/issues/20198 note: This is an unsafe fix and may change runtime behavior