[ty] Reduce monomorphization in add_binding (#22196)

This commit is contained in:
Micha Reiser
2025-12-27 11:17:27 +01:00
committed by GitHub
parent 5d32ab8175
commit da188d5cf6

View File

@@ -1497,48 +1497,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
///
/// Returns the result of the `infer_value_ty` closure, which is called with the declared type
/// as type context.
fn add_binding(
fn add_binding<'a>(
&mut self,
node: AnyNodeRef,
node: AnyNodeRef<'a>,
binding: Definition<'db>,
infer_value_ty: impl FnOnce(&mut Self, TypeContext<'db>) -> Type<'db>,
) -> Type<'db> {
/// Arbitrary `__getitem__`/`__setitem__` methods on a class do not
/// necessarily guarantee that the passed-in value for `__setitem__` is stored and
/// can be retrieved unmodified via `__getitem__`. Therefore, we currently only
/// perform assignment-based narrowing on a few built-in classes (`list`, `dict`,
/// `bytesarray`, `TypedDict` and `collections` types) where we are confident that
/// this kind of narrowing can be performed soundly. This is the same approach as
/// pyright. TODO: Other standard library classes may also be considered safe. Also,
/// subclasses of these safe classes that do not override `__getitem__/__setitem__`
/// may be considered safe.
fn is_safe_mutable_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
const SAFE_MUTABLE_CLASSES: &[KnownClass] = &[
KnownClass::List,
KnownClass::Dict,
KnownClass::Bytearray,
KnownClass::DefaultDict,
KnownClass::ChainMap,
KnownClass::Counter,
KnownClass::Deque,
KnownClass::OrderedDict,
];
SAFE_MUTABLE_CLASSES
.iter()
.map(|class| class.to_instance(db))
.any(|safe_mutable_class| {
ty.is_equivalent_to(db, safe_mutable_class)
|| ty
.generic_origin(db)
.zip(safe_mutable_class.generic_origin(db))
.is_some_and(|(l, r)| l == r)
})
}
) -> AddBinding<'db, 'a> {
let db = self.db();
debug_assert!(
binding
.kind(self.db())
.kind(db)
.category(self.context.in_stub(), self.module())
.is_binding()
);
@@ -1689,97 +1656,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
.or_else(|| resolved_place.ignore_possibly_undefined());
let inferred_ty = infer_value_ty(self, TypeContext::new(declared_ty));
let declared_ty = declared_ty.unwrap_or(Type::unknown());
let mut bound_ty = inferred_ty;
if qualifiers.contains(TypeQualifiers::FINAL) {
let mut previous_bindings = use_def.bindings_at_definition(binding);
// An assignment to a local `Final`-qualified symbol is only an error if there are prior bindings
let previous_definition = previous_bindings
.next()
.and_then(|r| r.binding.definition());
if !is_local || previous_definition.is_some() {
let place = place_table.place(binding.place(db));
if let Some(builder) = self.context.report_lint(
&INVALID_ASSIGNMENT,
binding.full_range(self.db(), self.module()),
) {
let mut diagnostic = builder.into_diagnostic(format_args!(
"Reassignment of `Final` symbol `{place}` is not allowed"
));
diagnostic.set_primary_message("Reassignment of `Final` symbol");
if let Some(previous_definition) = previous_definition {
// It is not very helpful to show the previous definition if it results from
// an import. Ideally, we would show the original definition in the external
// module, but that information is currently not threaded through attribute
// lookup.
if !previous_definition.kind(db).is_import() {
if let DefinitionKind::AnnotatedAssignment(assignment) =
previous_definition.kind(db)
{
let range = assignment.annotation(self.module()).range();
diagnostic.annotate(
self.context
.secondary(range)
.message("Symbol declared as `Final` here"),
);
} else {
let range =
previous_definition.full_range(self.db(), self.module());
diagnostic.annotate(
self.context
.secondary(range)
.message("Symbol declared as `Final` here"),
);
}
diagnostic.set_primary_message("Symbol later reassigned here");
}
}
}
}
AddBinding {
declared_ty,
binding,
node,
qualifiers,
is_local,
}
if !bound_ty.is_assignable_to(db, declared_ty) {
report_invalid_assignment(&self.context, node, binding, declared_ty, bound_ty);
// Allow declarations to override inference in case of invalid assignment.
bound_ty = declared_ty;
}
// In the following cases, the bound type may not be the same as the RHS value type.
if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node {
let value_ty = self.try_expression_type(value).unwrap_or_else(|| {
self.infer_maybe_standalone_expression(value, TypeContext::default())
});
// If the member is a data descriptor, the RHS value may differ from the value actually assigned.
if value_ty
.class_member(db, attr.id.clone())
.place
.ignore_possibly_undefined()
.is_some_and(|ty| ty.may_be_data_descriptor(db))
{
bound_ty = declared_ty;
}
} else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node {
let value_ty = self
.try_expression_type(value)
.unwrap_or_else(|| self.infer_expression(value, TypeContext::default()));
if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) {
bound_ty = declared_ty;
}
}
self.bindings
.insert(binding, bound_ty, self.multi_inference_state);
inferred_ty
}
/// Returns `true` if `symbol_id` should be looked up in the global scope, skipping intervening
@@ -2625,7 +2508,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// outer scope here.
fn infer_parameter_definition(
&mut self,
parameter_with_default: &ast::ParameterWithDefault,
parameter_with_default: &'ast ast::ParameterWithDefault,
definition: Definition<'db>,
) {
let ast::ParameterWithDefault {
@@ -2677,7 +2560,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::unknown()
};
self.add_binding(parameter.into(), definition, |_, _| ty);
self.add_binding(parameter.into(), definition)
.insert(self, ty);
}
}
@@ -2690,7 +2574,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// [`infer_parameter_definition`]: Self::infer_parameter_definition
fn infer_variadic_positional_parameter_definition(
&mut self,
parameter: &ast::Parameter,
parameter: &'ast ast::Parameter,
definition: Definition<'db>,
) {
if let Some(annotation) = parameter.annotation() {
@@ -2741,9 +2625,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&DeclaredAndInferredType::are_the_same_type(ty),
);
} else {
self.add_binding(parameter.into(), definition, |builder, _| {
Type::homogeneous_tuple(builder.db(), Type::unknown())
});
let inferred_ty = Type::homogeneous_tuple(self.db(), Type::unknown());
self.add_binding(parameter.into(), definition)
.insert(self, inferred_ty);
}
}
@@ -2832,7 +2716,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// [`infer_parameter_definition`]: Self::infer_parameter_definition
fn infer_variadic_keyword_parameter_definition(
&mut self,
parameter: &ast::Parameter,
parameter: &'ast ast::Parameter,
definition: Definition<'db>,
) {
if let Some(annotation) = parameter.annotation() {
@@ -2884,12 +2768,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&DeclaredAndInferredType::are_the_same_type(ty),
);
} else {
self.add_binding(parameter.into(), definition, |builder, _| {
KnownClass::Dict.to_specialized_instance(
builder.db(),
[KnownClass::Str.to_instance(builder.db()), Type::unknown()],
)
});
let inferred_ty = KnownClass::Dict.to_specialized_instance(
self.db(),
[KnownClass::Str.to_instance(self.db()), Type::unknown()],
);
self.add_binding(parameter.into(), definition)
.insert(self, inferred_ty);
}
}
@@ -3237,7 +3122,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
self.store_expression_type(target, target_ty);
self.add_binding(target.into(), definition, |_, _| target_ty);
self.add_binding(target.into(), definition)
.insert(self, target_ty);
}
/// Infers the type of a context expression (`with expr`) and returns the target's type
@@ -3392,8 +3278,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.add_binding(
except_handler_definition.node(self.module()).into(),
definition,
|_, _| symbol_ty,
);
)
.insert(self, symbol_ty);
}
fn infer_typevar_definition(
@@ -3790,7 +3676,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_match_pattern_definition(
&mut self,
pattern: &ast::Pattern,
pattern: &'ast ast::Pattern,
_index: u32,
definition: Definition<'db>,
) {
@@ -3798,9 +3684,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// against the subject expression type (which we can query via `infer_expression_types`)
// and extract the type at the `index` position if the pattern matches. This will be
// similar to the logic in `self.infer_assignment_definition`.
self.add_binding(pattern.into(), definition, |_, _| {
todo_type!("`match` pattern definition types")
});
self.add_binding(pattern.into(), definition)
.insert(self, todo_type!("`match` pattern definition types"));
}
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
@@ -5315,11 +5200,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) {
let target = assignment.target(self.module());
self.add_binding(target.into(), definition, |builder, tcx| {
let target_ty = builder.infer_assignment_definition_impl(assignment, definition, tcx);
builder.store_expression_type(target, target_ty);
target_ty
});
let add = self.add_binding(target.into(), definition);
let target_ty =
self.infer_assignment_definition_impl(assignment, definition, add.type_context());
self.store_expression_type(target, target_ty);
add.insert(self, target_ty);
}
fn infer_assignment_definition_impl(
@@ -6307,11 +6192,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_augment_assignment_definition(
&mut self,
assignment: &ast::StmtAugAssign,
assignment: &'ast ast::StmtAugAssign,
definition: Definition<'db>,
) {
let target_ty = self.infer_augment_assignment(assignment);
self.add_binding(assignment.into(), definition, |_, _| target_ty);
self.add_binding(assignment.into(), definition)
.insert(self, target_ty);
}
fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> {
@@ -6411,7 +6297,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
self.store_expression_type(target, loop_var_value_type);
self.add_binding(target.into(), definition, |_, _| loop_var_value_type);
self.add_binding(target.into(), definition)
.insert(self, loop_var_value_type);
}
fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) {
@@ -6936,16 +6823,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// is all we need to be doing here).
fn infer_import_from_submodule_definition(
&mut self,
import_from: &ast::StmtImportFrom,
import_from: &'ast ast::StmtImportFrom,
definition: Definition<'db>,
) {
// Get this package's absolute module name by resolving `.`, and make sure it exists
let Ok(thispackage_name) = ModuleName::package_for_file(self.db(), self.file()) else {
self.add_binding(import_from.into(), definition, |_, _| Type::unknown());
self.add_binding(import_from.into(), definition)
.insert(self, Type::unknown());
return;
};
let Some(module) = resolve_module(self.db(), self.file(), &thispackage_name) else {
self.add_binding(import_from.into(), definition, |_, _| Type::unknown());
self.add_binding(import_from.into(), definition)
.insert(self, Type::unknown());
return;
};
@@ -6969,7 +6858,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.next()
.and_then(ModuleName::new)
}) else {
self.add_binding(import_from.into(), definition, |_, _| Type::unknown());
self.add_binding(import_from.into(), definition)
.insert(self, Type::unknown());
return;
};
@@ -6984,12 +6874,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// We explicitly don't introduce a *declaration* because it's actual ok
// (and fairly common) to overwrite this import with a function or class
// and we don't want it to be a type error to do so.
self.add_binding(import_from.into(), definition, |_, _| submodule_type);
self.add_binding(import_from.into(), definition)
.insert(self, submodule_type);
return;
}
// That didn't work, try to produce diagnostics
self.add_binding(import_from.into(), definition, |_, _| Type::unknown());
self.add_binding(import_from.into(), definition)
.insert(self, Type::unknown());
if !self.is_reachable(import_from) {
return;
@@ -8519,7 +8411,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
self.expressions.insert(target.into(), target_type);
self.add_binding(target.into(), definition, |_, _| target_type);
self.add_binding(target.into(), definition)
.insert(self, target_type);
}
fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> {
@@ -8539,7 +8432,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_named_expression_definition(
&mut self,
named: &ast::ExprNamed,
named: &'ast ast::ExprNamed,
definition: Definition<'db>,
) -> Type<'db> {
let ast::ExprNamed {
@@ -8549,11 +8442,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
value,
} = named;
self.add_binding(named.target.as_ref().into(), definition, |builder, tcx| {
let ty = builder.infer_expression(value, tcx);
builder.store_expression_type(target, ty);
ty
})
let add = self.add_binding(named.target.as_ref().into(), definition);
let ty = self.infer_expression(value, add.type_context());
self.store_expression_type(target, ty);
add.insert(self, ty)
}
fn infer_if_expression(
@@ -13510,3 +13403,161 @@ impl<V> IntoIterator for VecSet<V> {
self.0.into_iter()
}
}
#[must_use]
struct AddBinding<'db, 'ast> {
declared_ty: Option<Type<'db>>,
binding: Definition<'db>,
node: AnyNodeRef<'ast>,
qualifiers: TypeQualifiers,
is_local: bool,
}
impl<'db, 'ast> AddBinding<'db, 'ast> {
fn type_context(&self) -> TypeContext<'db> {
TypeContext::new(self.declared_ty)
}
fn insert(
self,
builder: &mut TypeInferenceBuilder<'db, 'ast>,
inferred_ty: Type<'db>,
) -> Type<'db> {
let declared_ty = self.declared_ty.unwrap_or(Type::unknown());
let db = builder.db();
let file_scope_id = self.binding.file_scope(db);
let use_def = builder.index.use_def_map(file_scope_id);
let place_table = builder.index.place_table(file_scope_id);
let mut bound_ty = inferred_ty;
if self.qualifiers.contains(TypeQualifiers::FINAL) {
let mut previous_bindings = use_def.bindings_at_definition(self.binding);
// An assignment to a local `Final`-qualified symbol is only an error if there are prior bindings
let previous_definition = previous_bindings
.next()
.and_then(|r| r.binding.definition());
if !self.is_local || previous_definition.is_some() {
let place = place_table.place(self.binding.place(db));
if let Some(diag_builder) = builder.context.report_lint(
&INVALID_ASSIGNMENT,
self.binding.full_range(builder.db(), builder.module()),
) {
let mut diagnostic = diag_builder.into_diagnostic(format_args!(
"Reassignment of `Final` symbol `{place}` is not allowed"
));
diagnostic.set_primary_message("Reassignment of `Final` symbol");
if let Some(previous_definition) = previous_definition {
// It is not very helpful to show the previous definition if it results from
// an import. Ideally, we would show the original definition in the external
// module, but that information is currently not threaded through attribute
// lookup.
if !previous_definition.kind(db).is_import() {
if let DefinitionKind::AnnotatedAssignment(assignment) =
previous_definition.kind(db)
{
let range = assignment.annotation(builder.module()).range();
diagnostic.annotate(
builder
.context
.secondary(range)
.message("Symbol declared as `Final` here"),
);
} else {
let range = previous_definition.full_range(db, builder.module());
diagnostic.annotate(
builder
.context
.secondary(range)
.message("Symbol declared as `Final` here"),
);
}
diagnostic.set_primary_message("Symbol later reassigned here");
}
}
}
}
}
if !bound_ty.is_assignable_to(db, declared_ty) {
report_invalid_assignment(
&builder.context,
self.node,
self.binding,
declared_ty,
bound_ty,
);
// Allow declarations to override inference in case of invalid assignment.
bound_ty = declared_ty;
}
// In the following cases, the bound type may not be the same as the RHS value type.
if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = self.node {
let value_ty = builder.try_expression_type(value).unwrap_or_else(|| {
builder.infer_maybe_standalone_expression(value, TypeContext::default())
});
// If the member is a data descriptor, the RHS value may differ from the value actually assigned.
if value_ty
.class_member(db, attr.id.clone())
.place
.ignore_possibly_undefined()
.is_some_and(|ty| ty.may_be_data_descriptor(db))
{
bound_ty = declared_ty;
}
} else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = self.node {
let value_ty = builder
.try_expression_type(value)
.unwrap_or_else(|| builder.infer_expression(value, TypeContext::default()));
if !value_ty.is_typed_dict() && !Self::is_safe_mutable_class(db, value_ty) {
bound_ty = declared_ty;
}
}
builder
.bindings
.insert(self.binding, bound_ty, builder.multi_inference_state);
inferred_ty
}
/// Arbitrary `__getitem__`/`__setitem__` methods on a class do not
/// necessarily guarantee that the passed-in value for `__setitem__` is stored and
/// can be retrieved unmodified via `__getitem__`. Therefore, we currently only
/// perform assignment-based narrowing on a few built-in classes (`list`, `dict`,
/// `bytesarray`, `TypedDict` and `collections` types) where we are confident that
/// this kind of narrowing can be performed soundly. This is the same approach as
/// pyright. TODO: Other standard library classes may also be considered safe. Also,
/// subclasses of these safe classes that do not override `__getitem__/__setitem__`
/// may be considered safe.
fn is_safe_mutable_class(db: &'db dyn Db, ty: Type<'db>) -> bool {
const SAFE_MUTABLE_CLASSES: &[KnownClass] = &[
KnownClass::List,
KnownClass::Dict,
KnownClass::Bytearray,
KnownClass::DefaultDict,
KnownClass::ChainMap,
KnownClass::Counter,
KnownClass::Deque,
KnownClass::OrderedDict,
];
SAFE_MUTABLE_CLASSES
.iter()
.map(|class| class.to_instance(db))
.any(|safe_mutable_class| {
ty.is_equivalent_to(db, safe_mutable_class)
|| ty
.generic_origin(db)
.zip(safe_mutable_class.generic_origin(db))
.is_some_and(|(l, r)| l == r)
})
}
}