diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 96a2c97eb8..0ea7b471ab 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -1497,48 +1497,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// /// Returns the result of the `infer_value_ty` closure, which is called with the declared type /// as type context. - fn add_binding( + fn add_binding<'a>( &mut self, - node: AnyNodeRef, + node: AnyNodeRef<'a>, binding: Definition<'db>, - infer_value_ty: impl FnOnce(&mut Self, TypeContext<'db>) -> Type<'db>, - ) -> Type<'db> { - /// Arbitrary `__getitem__`/`__setitem__` methods on a class do not - /// necessarily guarantee that the passed-in value for `__setitem__` is stored and - /// can be retrieved unmodified via `__getitem__`. Therefore, we currently only - /// perform assignment-based narrowing on a few built-in classes (`list`, `dict`, - /// `bytesarray`, `TypedDict` and `collections` types) where we are confident that - /// this kind of narrowing can be performed soundly. This is the same approach as - /// pyright. TODO: Other standard library classes may also be considered safe. Also, - /// subclasses of these safe classes that do not override `__getitem__/__setitem__` - /// may be considered safe. - fn is_safe_mutable_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { - const SAFE_MUTABLE_CLASSES: &[KnownClass] = &[ - KnownClass::List, - KnownClass::Dict, - KnownClass::Bytearray, - KnownClass::DefaultDict, - KnownClass::ChainMap, - KnownClass::Counter, - KnownClass::Deque, - KnownClass::OrderedDict, - ]; - - SAFE_MUTABLE_CLASSES - .iter() - .map(|class| class.to_instance(db)) - .any(|safe_mutable_class| { - ty.is_equivalent_to(db, safe_mutable_class) - || ty - .generic_origin(db) - .zip(safe_mutable_class.generic_origin(db)) - .is_some_and(|(l, r)| l == r) - }) - } - + ) -> AddBinding<'db, 'a> { + let db = self.db(); debug_assert!( binding - .kind(self.db()) + .kind(db) .category(self.context.in_stub(), self.module()) .is_binding() ); @@ -1689,97 +1656,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } .or_else(|| resolved_place.ignore_possibly_undefined()); - let inferred_ty = infer_value_ty(self, TypeContext::new(declared_ty)); - - let declared_ty = declared_ty.unwrap_or(Type::unknown()); - let mut bound_ty = inferred_ty; - - if qualifiers.contains(TypeQualifiers::FINAL) { - let mut previous_bindings = use_def.bindings_at_definition(binding); - - // An assignment to a local `Final`-qualified symbol is only an error if there are prior bindings - - let previous_definition = previous_bindings - .next() - .and_then(|r| r.binding.definition()); - - if !is_local || previous_definition.is_some() { - let place = place_table.place(binding.place(db)); - if let Some(builder) = self.context.report_lint( - &INVALID_ASSIGNMENT, - binding.full_range(self.db(), self.module()), - ) { - let mut diagnostic = builder.into_diagnostic(format_args!( - "Reassignment of `Final` symbol `{place}` is not allowed" - )); - - diagnostic.set_primary_message("Reassignment of `Final` symbol"); - - if let Some(previous_definition) = previous_definition { - // It is not very helpful to show the previous definition if it results from - // an import. Ideally, we would show the original definition in the external - // module, but that information is currently not threaded through attribute - // lookup. - if !previous_definition.kind(db).is_import() { - if let DefinitionKind::AnnotatedAssignment(assignment) = - previous_definition.kind(db) - { - let range = assignment.annotation(self.module()).range(); - diagnostic.annotate( - self.context - .secondary(range) - .message("Symbol declared as `Final` here"), - ); - } else { - let range = - previous_definition.full_range(self.db(), self.module()); - diagnostic.annotate( - self.context - .secondary(range) - .message("Symbol declared as `Final` here"), - ); - } - diagnostic.set_primary_message("Symbol later reassigned here"); - } - } - } - } + AddBinding { + declared_ty, + binding, + node, + qualifiers, + is_local, } - - if !bound_ty.is_assignable_to(db, declared_ty) { - report_invalid_assignment(&self.context, node, binding, declared_ty, bound_ty); - - // Allow declarations to override inference in case of invalid assignment. - bound_ty = declared_ty; - } - // In the following cases, the bound type may not be the same as the RHS value type. - if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node { - let value_ty = self.try_expression_type(value).unwrap_or_else(|| { - self.infer_maybe_standalone_expression(value, TypeContext::default()) - }); - // If the member is a data descriptor, the RHS value may differ from the value actually assigned. - if value_ty - .class_member(db, attr.id.clone()) - .place - .ignore_possibly_undefined() - .is_some_and(|ty| ty.may_be_data_descriptor(db)) - { - bound_ty = declared_ty; - } - } else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node { - let value_ty = self - .try_expression_type(value) - .unwrap_or_else(|| self.infer_expression(value, TypeContext::default())); - - if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) { - bound_ty = declared_ty; - } - } - - self.bindings - .insert(binding, bound_ty, self.multi_inference_state); - - inferred_ty } /// Returns `true` if `symbol_id` should be looked up in the global scope, skipping intervening @@ -2625,7 +2508,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// outer scope here. fn infer_parameter_definition( &mut self, - parameter_with_default: &ast::ParameterWithDefault, + parameter_with_default: &'ast ast::ParameterWithDefault, definition: Definition<'db>, ) { let ast::ParameterWithDefault { @@ -2677,7 +2560,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::unknown() }; - self.add_binding(parameter.into(), definition, |_, _| ty); + self.add_binding(parameter.into(), definition) + .insert(self, ty); } } @@ -2690,7 +2574,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// [`infer_parameter_definition`]: Self::infer_parameter_definition fn infer_variadic_positional_parameter_definition( &mut self, - parameter: &ast::Parameter, + parameter: &'ast ast::Parameter, definition: Definition<'db>, ) { if let Some(annotation) = parameter.annotation() { @@ -2741,9 +2625,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &DeclaredAndInferredType::are_the_same_type(ty), ); } else { - self.add_binding(parameter.into(), definition, |builder, _| { - Type::homogeneous_tuple(builder.db(), Type::unknown()) - }); + let inferred_ty = Type::homogeneous_tuple(self.db(), Type::unknown()); + self.add_binding(parameter.into(), definition) + .insert(self, inferred_ty); } } @@ -2832,7 +2716,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// [`infer_parameter_definition`]: Self::infer_parameter_definition fn infer_variadic_keyword_parameter_definition( &mut self, - parameter: &ast::Parameter, + parameter: &'ast ast::Parameter, definition: Definition<'db>, ) { if let Some(annotation) = parameter.annotation() { @@ -2884,12 +2768,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &DeclaredAndInferredType::are_the_same_type(ty), ); } else { - self.add_binding(parameter.into(), definition, |builder, _| { - KnownClass::Dict.to_specialized_instance( - builder.db(), - [KnownClass::Str.to_instance(builder.db()), Type::unknown()], - ) - }); + let inferred_ty = KnownClass::Dict.to_specialized_instance( + self.db(), + [KnownClass::Str.to_instance(self.db()), Type::unknown()], + ); + + self.add_binding(parameter.into(), definition) + .insert(self, inferred_ty); } } @@ -3237,7 +3122,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; self.store_expression_type(target, target_ty); - self.add_binding(target.into(), definition, |_, _| target_ty); + self.add_binding(target.into(), definition) + .insert(self, target_ty); } /// Infers the type of a context expression (`with expr`) and returns the target's type @@ -3392,8 +3278,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.add_binding( except_handler_definition.node(self.module()).into(), definition, - |_, _| symbol_ty, - ); + ) + .insert(self, symbol_ty); } fn infer_typevar_definition( @@ -3790,7 +3676,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_match_pattern_definition( &mut self, - pattern: &ast::Pattern, + pattern: &'ast ast::Pattern, _index: u32, definition: Definition<'db>, ) { @@ -3798,9 +3684,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // against the subject expression type (which we can query via `infer_expression_types`) // and extract the type at the `index` position if the pattern matches. This will be // similar to the logic in `self.infer_assignment_definition`. - self.add_binding(pattern.into(), definition, |_, _| { - todo_type!("`match` pattern definition types") - }); + self.add_binding(pattern.into(), definition) + .insert(self, todo_type!("`match` pattern definition types")); } fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { @@ -5315,11 +5200,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) { let target = assignment.target(self.module()); - self.add_binding(target.into(), definition, |builder, tcx| { - let target_ty = builder.infer_assignment_definition_impl(assignment, definition, tcx); - builder.store_expression_type(target, target_ty); - target_ty - }); + let add = self.add_binding(target.into(), definition); + let target_ty = + self.infer_assignment_definition_impl(assignment, definition, add.type_context()); + self.store_expression_type(target, target_ty); + add.insert(self, target_ty); } fn infer_assignment_definition_impl( @@ -6307,11 +6192,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_augment_assignment_definition( &mut self, - assignment: &ast::StmtAugAssign, + assignment: &'ast ast::StmtAugAssign, definition: Definition<'db>, ) { let target_ty = self.infer_augment_assignment(assignment); - self.add_binding(assignment.into(), definition, |_, _| target_ty); + self.add_binding(assignment.into(), definition) + .insert(self, target_ty); } fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> { @@ -6411,7 +6297,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; self.store_expression_type(target, loop_var_value_type); - self.add_binding(target.into(), definition, |_, _| loop_var_value_type); + self.add_binding(target.into(), definition) + .insert(self, loop_var_value_type); } fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) { @@ -6936,16 +6823,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// is all we need to be doing here). fn infer_import_from_submodule_definition( &mut self, - import_from: &ast::StmtImportFrom, + import_from: &'ast ast::StmtImportFrom, definition: Definition<'db>, ) { // Get this package's absolute module name by resolving `.`, and make sure it exists let Ok(thispackage_name) = ModuleName::package_for_file(self.db(), self.file()) else { - self.add_binding(import_from.into(), definition, |_, _| Type::unknown()); + self.add_binding(import_from.into(), definition) + .insert(self, Type::unknown()); return; }; let Some(module) = resolve_module(self.db(), self.file(), &thispackage_name) else { - self.add_binding(import_from.into(), definition, |_, _| Type::unknown()); + self.add_binding(import_from.into(), definition) + .insert(self, Type::unknown()); return; }; @@ -6969,7 +6858,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .next() .and_then(ModuleName::new) }) else { - self.add_binding(import_from.into(), definition, |_, _| Type::unknown()); + self.add_binding(import_from.into(), definition) + .insert(self, Type::unknown()); return; }; @@ -6984,12 +6874,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // We explicitly don't introduce a *declaration* because it's actual ok // (and fairly common) to overwrite this import with a function or class // and we don't want it to be a type error to do so. - self.add_binding(import_from.into(), definition, |_, _| submodule_type); + self.add_binding(import_from.into(), definition) + .insert(self, submodule_type); return; } // That didn't work, try to produce diagnostics - self.add_binding(import_from.into(), definition, |_, _| Type::unknown()); + self.add_binding(import_from.into(), definition) + .insert(self, Type::unknown()); if !self.is_reachable(import_from) { return; @@ -8519,7 +8411,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; self.expressions.insert(target.into(), target_type); - self.add_binding(target.into(), definition, |_, _| target_type); + self.add_binding(target.into(), definition) + .insert(self, target_type); } fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> { @@ -8539,7 +8432,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn infer_named_expression_definition( &mut self, - named: &ast::ExprNamed, + named: &'ast ast::ExprNamed, definition: Definition<'db>, ) -> Type<'db> { let ast::ExprNamed { @@ -8549,11 +8442,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { value, } = named; - self.add_binding(named.target.as_ref().into(), definition, |builder, tcx| { - let ty = builder.infer_expression(value, tcx); - builder.store_expression_type(target, ty); - ty - }) + let add = self.add_binding(named.target.as_ref().into(), definition); + + let ty = self.infer_expression(value, add.type_context()); + self.store_expression_type(target, ty); + add.insert(self, ty) } fn infer_if_expression( @@ -13510,3 +13403,161 @@ impl IntoIterator for VecSet { self.0.into_iter() } } + +#[must_use] +struct AddBinding<'db, 'ast> { + declared_ty: Option>, + binding: Definition<'db>, + node: AnyNodeRef<'ast>, + qualifiers: TypeQualifiers, + is_local: bool, +} + +impl<'db, 'ast> AddBinding<'db, 'ast> { + fn type_context(&self) -> TypeContext<'db> { + TypeContext::new(self.declared_ty) + } + + fn insert( + self, + builder: &mut TypeInferenceBuilder<'db, 'ast>, + inferred_ty: Type<'db>, + ) -> Type<'db> { + let declared_ty = self.declared_ty.unwrap_or(Type::unknown()); + + let db = builder.db(); + let file_scope_id = self.binding.file_scope(db); + let use_def = builder.index.use_def_map(file_scope_id); + let place_table = builder.index.place_table(file_scope_id); + + let mut bound_ty = inferred_ty; + + if self.qualifiers.contains(TypeQualifiers::FINAL) { + let mut previous_bindings = use_def.bindings_at_definition(self.binding); + + // An assignment to a local `Final`-qualified symbol is only an error if there are prior bindings + + let previous_definition = previous_bindings + .next() + .and_then(|r| r.binding.definition()); + + if !self.is_local || previous_definition.is_some() { + let place = place_table.place(self.binding.place(db)); + if let Some(diag_builder) = builder.context.report_lint( + &INVALID_ASSIGNMENT, + self.binding.full_range(builder.db(), builder.module()), + ) { + let mut diagnostic = diag_builder.into_diagnostic(format_args!( + "Reassignment of `Final` symbol `{place}` is not allowed" + )); + + diagnostic.set_primary_message("Reassignment of `Final` symbol"); + + if let Some(previous_definition) = previous_definition { + // It is not very helpful to show the previous definition if it results from + // an import. Ideally, we would show the original definition in the external + // module, but that information is currently not threaded through attribute + // lookup. + if !previous_definition.kind(db).is_import() { + if let DefinitionKind::AnnotatedAssignment(assignment) = + previous_definition.kind(db) + { + let range = assignment.annotation(builder.module()).range(); + diagnostic.annotate( + builder + .context + .secondary(range) + .message("Symbol declared as `Final` here"), + ); + } else { + let range = previous_definition.full_range(db, builder.module()); + diagnostic.annotate( + builder + .context + .secondary(range) + .message("Symbol declared as `Final` here"), + ); + } + diagnostic.set_primary_message("Symbol later reassigned here"); + } + } + } + } + } + + if !bound_ty.is_assignable_to(db, declared_ty) { + report_invalid_assignment( + &builder.context, + self.node, + self.binding, + declared_ty, + bound_ty, + ); + + // Allow declarations to override inference in case of invalid assignment. + bound_ty = declared_ty; + } + // In the following cases, the bound type may not be the same as the RHS value type. + if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = self.node { + let value_ty = builder.try_expression_type(value).unwrap_or_else(|| { + builder.infer_maybe_standalone_expression(value, TypeContext::default()) + }); + // If the member is a data descriptor, the RHS value may differ from the value actually assigned. + if value_ty + .class_member(db, attr.id.clone()) + .place + .ignore_possibly_undefined() + .is_some_and(|ty| ty.may_be_data_descriptor(db)) + { + bound_ty = declared_ty; + } + } else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = self.node { + let value_ty = builder + .try_expression_type(value) + .unwrap_or_else(|| builder.infer_expression(value, TypeContext::default())); + + if !value_ty.is_typed_dict() && !Self::is_safe_mutable_class(db, value_ty) { + bound_ty = declared_ty; + } + } + + builder + .bindings + .insert(self.binding, bound_ty, builder.multi_inference_state); + + inferred_ty + } + + /// Arbitrary `__getitem__`/`__setitem__` methods on a class do not + /// necessarily guarantee that the passed-in value for `__setitem__` is stored and + /// can be retrieved unmodified via `__getitem__`. Therefore, we currently only + /// perform assignment-based narrowing on a few built-in classes (`list`, `dict`, + /// `bytesarray`, `TypedDict` and `collections` types) where we are confident that + /// this kind of narrowing can be performed soundly. This is the same approach as + /// pyright. TODO: Other standard library classes may also be considered safe. Also, + /// subclasses of these safe classes that do not override `__getitem__/__setitem__` + /// may be considered safe. + fn is_safe_mutable_class(db: &'db dyn Db, ty: Type<'db>) -> bool { + const SAFE_MUTABLE_CLASSES: &[KnownClass] = &[ + KnownClass::List, + KnownClass::Dict, + KnownClass::Bytearray, + KnownClass::DefaultDict, + KnownClass::ChainMap, + KnownClass::Counter, + KnownClass::Deque, + KnownClass::OrderedDict, + ]; + + SAFE_MUTABLE_CLASSES + .iter() + .map(|class| class.to_instance(db)) + .any(|safe_mutable_class| { + ty.is_equivalent_to(db, safe_mutable_class) + || ty + .generic_origin(db) + .zip(safe_mutable_class.generic_origin(db)) + .is_some_and(|(l, r)| l == r) + }) + } +}