[red-knot] Refactor `infer_chained_boolean_types` to have access to `TypeInferenceBuilder` (#16222)

This commit is contained in:
Micha Reiser 2025-02-19 10:13:35 +00:00 committed by GitHub
parent 01c3e6b94f
commit e84985e9b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 89 additions and 77 deletions

View File

@ -4025,59 +4025,72 @@ impl<'db> TypeInferenceBuilder<'db> {
op,
values,
} = bool_op;
Self::infer_chained_boolean_types(
self.db(),
self.infer_chained_boolean_types(
*op,
values.iter().enumerate().map(|(index, value)| {
values.iter().enumerate(),
|builder, (index, value)| {
if index == values.len() - 1 {
self.infer_expression(value)
builder.infer_expression(value)
} else {
self.infer_standalone_expression(value)
builder.infer_standalone_expression(value)
}
}),
values.len(),
},
)
}
/// Computes the output of a chain of (one) boolean operation, consuming as input an iterator
/// of types. The iterator is consumed even if the boolean evaluation can be short-circuited,
/// of operations and calling the `infer_ty` for each to infer their types.
/// The iterator is consumed even if the boolean evaluation can be short-circuited,
/// in order to ensure the invariant that all expressions are evaluated when inferring types.
fn infer_chained_boolean_types(
db: &'db dyn Db,
fn infer_chained_boolean_types<Iterator, Item, F>(
&mut self,
op: ast::BoolOp,
values: impl IntoIterator<Item = Type<'db>>,
n_values: usize,
) -> Type<'db> {
operations: Iterator,
infer_ty: F,
) -> Type<'db>
where
Iterator: IntoIterator<Item = Item>,
F: Fn(&mut Self, Item) -> Type<'db>,
{
let mut done = false;
let db = self.db();
let elements = values.into_iter().enumerate().map(|(i, ty)| {
if done {
return Type::Never;
}
let elements = operations
.into_iter()
.with_position()
.map(|(position, ty)| {
let ty = infer_ty(self, ty);
let is_last = i == n_values - 1;
match (ty.bool(db), is_last, op) {
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
done = true;
ty
if done {
return Type::Never;
}
(Truthiness::Ambiguous, false, _) => IntersectionBuilder::new(db)
.add_positive(ty)
.add_negative(match op {
ast::BoolOp::And => Type::AlwaysTruthy,
ast::BoolOp::Or => Type::AlwaysFalsy,
})
.build(),
let is_last = matches!(
position,
itertools::Position::Last | itertools::Position::Only
);
(_, true, _) => ty,
}
});
match (ty.bool(db), is_last, op) {
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
(Truthiness::AlwaysFalse, _, ast::BoolOp::And)
| (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
done = true;
ty
}
(Truthiness::Ambiguous, false, _) => IntersectionBuilder::new(db)
.add_positive(ty)
.add_negative(match op {
ast::BoolOp::And => Type::AlwaysTruthy,
ast::BoolOp::Or => Type::AlwaysFalsy,
})
.build(),
(_, true, _) => ty,
}
});
UnionType::from_elements(db, elements)
}
@ -4102,52 +4115,51 @@ impl<'db> TypeInferenceBuilder<'db> {
//
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
// is shared with the one in `infer_binary_type_comparison`.
Self::infer_chained_boolean_types(
self.db(),
self.infer_chained_boolean_types(
ast::BoolOp::And,
std::iter::once(&**left)
.chain(comparators)
.tuple_windows::<(_, _)>()
.zip(ops)
.map(|((left, right), op)| {
let left_ty = self.expression_type(left);
let right_ty = self.expression_type(right);
.zip(ops),
|builder, ((left, right), op)| {
let left_ty = builder.expression_type(left);
let right_ty = builder.expression_type(right);
self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|error| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
self.context.report_lint(
&UNSUPPORTED_OPERATOR,
AnyNodeRef::ExprCompare(compare),
format_args!(
"Operator `{}` is not supported for types `{}` and `{}`{}",
error.op,
error.left_ty.display(self.db()),
error.right_ty.display(self.db()),
if (left_ty, right_ty) == (error.left_ty, error.right_ty) {
String::new()
} else {
format!(
", in comparing `{}` with `{}`",
left_ty.display(self.db()),
right_ty.display(self.db())
)
}
),
);
builder
.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|error| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
builder.context.report_lint(
&UNSUPPORTED_OPERATOR,
AnyNodeRef::ExprCompare(compare),
format_args!(
"Operator `{}` is not supported for types `{}` and `{}`{}",
error.op,
error.left_ty.display(builder.db()),
error.right_ty.display(builder.db()),
if (left_ty, right_ty) == (error.left_ty, error.right_ty) {
String::new()
} else {
format!(
", in comparing `{}` with `{}`",
left_ty.display(builder.db()),
right_ty.display(builder.db())
)
}
),
);
match op {
// `in, not in, is, is not` always return bool instances
ast::CmpOp::In
| ast::CmpOp::NotIn
| ast::CmpOp::Is
| ast::CmpOp::IsNot => KnownClass::Bool.to_instance(self.db()),
// Other operators can return arbitrary types
_ => Type::unknown(),
}
})
}),
ops.len(),
match op {
// `in, not in, is, is not` always return bool instances
ast::CmpOp::In
| ast::CmpOp::NotIn
| ast::CmpOp::Is
| ast::CmpOp::IsNot => KnownClass::Bool.to_instance(builder.db()),
// Other operators can return arbitrary types
_ => Type::unknown(),
}
})
},
)
}