From 17eb65b26f3c451305792bcff2a09e2bdc0c3a91 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Mon, 2 Sep 2024 14:40:09 +0530 Subject: [PATCH] Add definitions for match statement (#13147) ## Summary This PR adds definitions for match patterns. ## Test Plan Update the existing test case for match statement symbols to verify that the definitions are added as well. --- .../src/ast_node_ref.rs | 2 +- .../red_knot_python_semantic/src/node_key.rs | 19 ++++- .../src/semantic_index.rs | 20 ++++- .../src/semantic_index/builder.rs | 82 +++++++++++++++++-- .../src/semantic_index/definition.rs | 55 +++++++++++++ .../src/types/infer.rs | 27 +++++- 6 files changed, 189 insertions(+), 16 deletions(-) diff --git a/crates/red_knot_python_semantic/src/ast_node_ref.rs b/crates/red_knot_python_semantic/src/ast_node_ref.rs index 94f7d5d268..6ea0267c0b 100644 --- a/crates/red_knot_python_semantic/src/ast_node_ref.rs +++ b/crates/red_knot_python_semantic/src/ast_node_ref.rs @@ -31,10 +31,10 @@ impl AstNodeRef { /// which the `AstNodeRef` belongs. /// /// ## Safety + /// /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the /// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that /// the invariant `node belongs to parsed` is upheld. - pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self { Self { _parsed: parsed, diff --git a/crates/red_knot_python_semantic/src/node_key.rs b/crates/red_knot_python_semantic/src/node_key.rs index 0935a1f839..9683b0e7fa 100644 --- a/crates/red_knot_python_semantic/src/node_key.rs +++ b/crates/red_knot_python_semantic/src/node_key.rs @@ -1,12 +1,18 @@ -use ruff_python_ast::{AnyNodeRef, NodeKind}; +use ruff_python_ast::{AnyNodeRef, Identifier, NodeKind}; use ruff_text_size::{Ranged, TextRange}; +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub(super) enum Kind { + Node(NodeKind), + Identifier, +} + /// Compact key for a node for use in a hash map. 
/// /// Compares two nodes by their kind and text range. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub(super) struct NodeKey { - kind: NodeKind, + kind: Kind, range: TextRange, } @@ -17,8 +23,15 @@ impl NodeKey { { let node = node.into(); NodeKey { - kind: node.kind(), + kind: Kind::Node(node.kind()), range: node.range(), } } + + pub(super) fn from_identifier(identifier: &Identifier) -> Self { + NodeKey { + kind: Kind::Identifier, + range: identifier.range(), + } + } } diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 3e0f08e2f3..6a5c96842f 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -1073,7 +1073,7 @@ def x(): } #[test] - fn match_stmt_symbols() { + fn match_stmt() { let TestCase { db, file } = test_case( " match subject: @@ -1087,13 +1087,27 @@ match subject: ", ); - let global_table = symbol_table(&db, global_scope(&db, file)); + let global_scope_id = global_scope(&db, file); + let global_table = symbol_table(&db, global_scope_id); assert!(global_table.symbol_by_name("Foo").unwrap().is_used()); assert_eq!( names(&global_table), - vec!["subject", "a", "b", "c", "d", "f", "e", "h", "g", "Foo", "i", "j", "k", "l"] + vec!["subject", "a", "b", "c", "d", "e", "f", "g", "h", "Foo", "i", "j", "k", "l"] ); + + let use_def = use_def_map(&db, global_scope_id); + for name in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] { + let definition = use_def + .first_public_definition( + global_table.symbol_id_by_name(name).expect("symbol exists"), + ) + .expect("Expected with item definition for {name}"); + assert!(matches!( + definition.node(&db), + DefinitionKind::MatchPattern(_) + )); + } } #[test] diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 543ae275c4..dfdab1ec71 100644 --- 
a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -26,7 +26,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder}; use crate::semantic_index::SemanticIndex; use crate::Db; -use super::definition::WithItemDefinitionNodeRef; +use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef}; pub(super) struct SemanticIndexBuilder<'db> { // Builder state @@ -600,6 +600,17 @@ where self.visit_body(body); self.visit_body(orelse); } + ast::Stmt::Match(ast::StmtMatch { + subject, + cases, + range: _, + }) => { + self.add_standalone_expression(subject); + self.visit_expr(subject); + for case in cases { + self.visit_match_case(case); + } + } _ => { walk_stmt(self, stmt); } @@ -803,22 +814,77 @@ where } fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { - if let ast::Pattern::MatchAs(ast::PatternMatchAs { - name: Some(name), .. - }) - | ast::Pattern::MatchStar(ast::PatternMatchStar { + // The definition visitor will recurse into the pattern so avoid walking it here. + let mut definition_visitor = MatchPatternDefinitionVisitor::new(self, pattern); + definition_visitor.visit_pattern(pattern); + } +} + +/// A visitor that adds symbols and definitions for the identifiers in a match pattern. +struct MatchPatternDefinitionVisitor<'a, 'db> { + /// The semantic index builder in which to add the symbols and definitions. + builder: &'a mut SemanticIndexBuilder<'db>, + /// The index of the current node in the pattern. + index: u32, + /// The pattern being visited. This pattern is the outermost pattern that is being visited + /// and is required to add the definitions. 
+ pattern: &'a ast::Pattern, +} + +impl<'a, 'db> MatchPatternDefinitionVisitor<'a, 'db> { + fn new(builder: &'a mut SemanticIndexBuilder<'db>, pattern: &'a ast::Pattern) -> Self { + Self { + index: 0, + builder, + pattern, + } + } + + fn add_symbol_and_definition(&mut self, identifier: &ast::Identifier) { + let symbol = self + .builder + .add_or_update_symbol(identifier.id().clone(), SymbolFlags::IS_DEFINED); + self.builder.add_definition( + symbol, + MatchPatternDefinitionNodeRef { + pattern: self.pattern, + identifier, + index: self.index, + }, + ); + } +} + +impl<'ast, 'db> Visitor<'ast> for MatchPatternDefinitionVisitor<'_, 'db> +where + 'ast: 'db, +{ + fn visit_expr(&mut self, expr: &'ast ast::Expr) { + self.builder.visit_expr(expr); + } + + fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { + if let ast::Pattern::MatchStar(ast::PatternMatchStar { name: Some(name), range: _, + }) = pattern + { + self.add_symbol_and_definition(name); + } + + walk_pattern(self, pattern); + + if let ast::Pattern::MatchAs(ast::PatternMatchAs { + name: Some(name), .. }) | ast::Pattern::MatchMapping(ast::PatternMatchMapping { rest: Some(name), .. 
}) = pattern { - // TODO(dhruvmanila): Add definition - self.add_or_update_symbol(name.id.clone(), SymbolFlags::IS_DEFINED); + self.add_symbol_and_definition(name); } - walk_pattern(self, pattern); + self.index += 1; } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 24b4b8e23f..75c95a4bd5 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -49,6 +49,7 @@ pub(crate) enum DefinitionNodeRef<'a> { Comprehension(ComprehensionDefinitionNodeRef<'a>), Parameter(ast::AnyParameterRef<'a>), WithItem(WithItemDefinitionNodeRef<'a>), + MatchPattern(MatchPatternDefinitionNodeRef<'a>), } impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { @@ -123,6 +124,12 @@ impl<'a> From> for DefinitionNodeRef<'a> { } } +impl<'a> From> for DefinitionNodeRef<'a> { + fn from(node: MatchPatternDefinitionNodeRef<'a>) -> Self { + Self::MatchPattern(node) + } +} + #[derive(Copy, Clone, Debug)] pub(crate) struct ImportFromDefinitionNodeRef<'a> { pub(crate) node: &'a ast::StmtImportFrom, @@ -153,6 +160,17 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> { pub(crate) first: bool, } +#[derive(Copy, Clone, Debug)] +pub(crate) struct MatchPatternDefinitionNodeRef<'a> { + /// The outermost pattern node in which the identifier being defined occurs. + pub(crate) pattern: &'a ast::Pattern, + /// The identifier being defined. + pub(crate) identifier: &'a ast::Identifier, + /// The index of the identifier in the pattern when visiting the `pattern` node in evaluation + /// order. 
+ pub(crate) index: u32, +} + impl DefinitionNodeRef<'_> { #[allow(unsafe_code)] pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind { @@ -213,6 +231,15 @@ impl DefinitionNodeRef<'_> { target: AstNodeRef::new(parsed, target), }) } + DefinitionNodeRef::MatchPattern(MatchPatternDefinitionNodeRef { + pattern, + identifier, + index, + }) => DefinitionKind::MatchPattern(MatchPatternDefinitionKind { + pattern: AstNodeRef::new(parsed.clone(), pattern), + identifier: AstNodeRef::new(parsed, identifier), + index, + }), } } @@ -241,6 +268,9 @@ impl DefinitionNodeRef<'_> { ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(), }, Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(), + Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => { + identifier.into() + } } } } @@ -260,6 +290,25 @@ pub enum DefinitionKind { Parameter(AstNodeRef), ParameterWithDefault(AstNodeRef), WithItem(WithItemDefinitionKind), + MatchPattern(MatchPatternDefinitionKind), +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub struct MatchPatternDefinitionKind { + pattern: AstNodeRef, + identifier: AstNodeRef, + index: u32, +} + +impl MatchPatternDefinitionKind { + pub(crate) fn pattern(&self) -> &ast::Pattern { + self.pattern.node() + } + + pub(crate) fn index(&self) -> u32 { + self.index + } } #[derive(Clone, Debug)] @@ -410,3 +459,9 @@ impl From<&ast::ParameterWithDefault> for DefinitionNodeKey { Self(NodeKey::from_node(node)) } } + +impl From<&ast::Identifier> for DefinitionNodeKey { + fn from(identifier: &ast::Identifier) -> Self { + Self(NodeKey::from_identifier(identifier)) + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 5d0abf46b7..66162f9e5a 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -416,6 +416,13 @@ impl<'db> TypeInferenceBuilder<'db> { 
DefinitionKind::WithItem(with_item) => { self.infer_with_item_definition(with_item.target(), with_item.node(), definition); } + DefinitionKind::MatchPattern(match_pattern) => { + self.infer_match_pattern_definition( + match_pattern.pattern(), + match_pattern.index(), + definition, + ); + } } } @@ -795,7 +802,10 @@ impl<'db> TypeInferenceBuilder<'db> { cases, } = match_statement; - self.infer_expression(subject); + let expression = self.index.expression(subject.as_ref()); + let result = infer_expression_types(self.db, expression); + self.extend(result); + for case in cases { let ast::MatchCase { range: _, @@ -809,7 +819,22 @@ impl<'db> TypeInferenceBuilder<'db> { } } + fn infer_match_pattern_definition( + &mut self, + _pattern: &ast::Pattern, + _index: u32, + definition: Definition<'db>, + ) { + // TODO(dhruvmanila): The correct way to infer types here is to perform structural matching + // against the subject expression type (which we can query via `infer_expression_types`) + // and extract the type at the `index` position if the pattern matches. This will be + // similar to the logic in `self.infer_assignment_definition`. + self.types.definitions.insert(definition, Type::Unknown); + } + fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { + // TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against + // the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510 match pattern { ast::Pattern::MatchValue(match_value) => { self.infer_expression(&match_value.value);