diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 56c7e31d85..1d6aa7aafb 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -463,6 +463,25 @@ mod tests { )); } + #[test] + fn augmented_assignment() { + let TestCase { db, file } = test_case("x += 1"); + let scope = global_scope(&db, file); + let global_table = symbol_table(&db, scope); + + assert_eq!(names(&global_table), vec!["x"]); + + let use_def = use_def_map(&db, scope); + let definition = use_def + .first_public_definition(global_table.symbol_id_by_name("x").unwrap()) + .unwrap(); + + assert!(matches!( + definition.node(&db), + DefinitionKind::AugmentedAssignment(_) + )); + } + #[test] fn class_scope() { let TestCase { db, file } = test_case( diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 5a9b63d7fa..d6a7b82151 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -495,6 +495,20 @@ where self.visit_expr(&node.target); self.current_assignment = None; } + ast::Stmt::AugAssign( + aug_assign @ ast::StmtAugAssign { + range: _, + target, + op: _, + value, + }, + ) => { + debug_assert!(self.current_assignment.is_none()); + self.visit_expr(value); + self.current_assignment = Some(aug_assign.into()); + self.visit_expr(target); + self.current_assignment = None; + } ast::Stmt::If(node) => { self.visit_expr(&node.test); let pre_if = self.flow_snapshot(); @@ -563,12 +577,21 @@ where match expr { ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => { - let flags = match ctx { + let mut flags = match ctx { ast::ExprContext::Load => SymbolFlags::IS_USED, ast::ExprContext::Store => SymbolFlags::IS_DEFINED, ast::ExprContext::Del => SymbolFlags::IS_DEFINED, ast::ExprContext::Invalid => SymbolFlags::empty(), }; + if matches!( + self.current_assignment, + Some(CurrentAssignment::AugAssign(_)) + ) && !ctx.is_invalid() + { + // For augmented assignment, the target expression is also used, so we should + // record that as a use. + flags |= SymbolFlags::IS_USED; + } let symbol = self.add_or_update_symbol(id.clone(), flags); if flags.contains(SymbolFlags::IS_DEFINED) { match self.current_assignment { @@ -584,6 +607,9 @@ where Some(CurrentAssignment::AnnAssign(ann_assign)) => { self.add_definition(symbol, ann_assign); } + Some(CurrentAssignment::AugAssign(aug_assign)) => { + self.add_definition(symbol, aug_assign); + } Some(CurrentAssignment::Named(named)) => { // TODO(dhruvmanila): If the current scope is a comprehension, then the // named expression is implicitly nonlocal. This is yet to be @@ -727,6 +753,7 @@ where enum CurrentAssignment<'a> { Assign(&'a ast::StmtAssign), AnnAssign(&'a ast::StmtAnnAssign), + AugAssign(&'a ast::StmtAugAssign), Named(&'a ast::ExprNamed), Comprehension { node: &'a ast::Comprehension, @@ -746,6 +773,12 @@ impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> { } } +impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> { + fn from(value: &'a ast::StmtAugAssign) -> Self { + Self::AugAssign(value) + } +} + impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { fn from(value: &'a ast::ExprNamed) -> Self { Self::Named(value) diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index e0d6211ac9..38ccaf5849 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -44,6 +44,7 @@ pub(crate) enum DefinitionNodeRef<'a> { NamedExpression(&'a ast::ExprNamed), Assignment(AssignmentDefinitionNodeRef<'a>), AnnotatedAssignment(&'a ast::StmtAnnAssign), + AugmentedAssignment(&'a ast::StmtAugAssign), Comprehension(ComprehensionDefinitionNodeRef<'a>), Parameter(ast::AnyParameterRef<'a>), } @@ -72,6 +73,12 @@ impl<'a> From<&'a ast::StmtAnnAssign> for DefinitionNodeRef<'a> { } } +impl<'a> From<&'a ast::StmtAugAssign> for DefinitionNodeRef<'a> { + fn from(node: &'a ast::StmtAugAssign) -> Self { + Self::AugmentedAssignment(node) + } +} + impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> { fn from(node_ref: &'a ast::Alias) -> Self { Self::Import(node_ref) @@ -151,6 +158,9 @@ impl DefinitionNodeRef<'_> { DefinitionNodeRef::AnnotatedAssignment(assign) => { DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) } + DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => { + DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment)) + } DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => { DefinitionKind::Comprehension(ComprehensionDefinitionKind { node: AstNodeRef::new(parsed, node), @@ -182,6 +192,7 @@ impl DefinitionNodeRef<'_> { target, }) => target.into(), Self::AnnotatedAssignment(node) => node.into(), + Self::AugmentedAssignment(node) => node.into(), Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(), Self::Parameter(node) => match node { ast::AnyParameterRef::Variadic(parameter) => parameter.into(), @@ -200,6 +211,7 @@ pub enum DefinitionKind { NamedExpression(AstNodeRef), Assignment(AssignmentDefinitionKind), AnnotatedAssignment(AstNodeRef), + AugmentedAssignment(AstNodeRef), Comprehension(ComprehensionDefinitionKind), Parameter(AstNodeRef), ParameterWithDefault(AstNodeRef), @@ -293,6 +305,12 @@ impl From<&ast::StmtAnnAssign> for DefinitionNodeKey { } } +impl From<&ast::StmtAugAssign> for DefinitionNodeKey { + fn from(node: &ast::StmtAugAssign) -> Self { + Self(NodeKey::from_node(node)) + } +} + impl From<&ast::Comprehension> for DefinitionNodeKey { fn from(node: &ast::Comprehension) -> Self { Self(NodeKey::from_node(node)) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index fc4f696666..7cad269146 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -303,6 +303,9 @@ impl<'db> TypeInferenceBuilder<'db> { DefinitionKind::AnnotatedAssignment(annotated_assignment) => { self.infer_annotated_assignment_definition(annotated_assignment.node(), definition); } + DefinitionKind::AugmentedAssignment(augmented_assignment) => { + self.infer_augment_assignment_definition(augmented_assignment.node(), definition); + } DefinitionKind::NamedExpression(named_expression) => { self.infer_named_expression_definition(named_expression.node(), definition); } @@ -763,15 +766,35 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_augmented_assignment_statement(&mut self, assignment: &ast::StmtAugAssign) { - // TODO this should be a Definition + if assignment.target.is_name_expr() { + self.infer_definition(assignment); + } else { + // TODO currently we don't consider assignments to non-Names to be Definitions + self.infer_augment_assignment(assignment); + } + } + + fn infer_augment_assignment_definition( + &mut self, + assignment: &ast::StmtAugAssign, + definition: Definition<'db>, + ) { + let target_ty = self.infer_augment_assignment(assignment); + self.types.definitions.insert(definition, target_ty); + } + + fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> { let ast::StmtAugAssign { range: _, target, op: _, value, } = assignment; - self.infer_expression(target); self.infer_expression(value); + self.infer_expression(target); + + // TODO(dhruvmanila): Resolve the target type using the value type and the operator + Type::Unknown } fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {