diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 5c24ae8195..923abd79df 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -4025,59 +4025,72 @@ impl<'db> TypeInferenceBuilder<'db> { op, values, } = bool_op; - Self::infer_chained_boolean_types( - self.db(), + self.infer_chained_boolean_types( *op, - values.iter().enumerate().map(|(index, value)| { + values.iter().enumerate(), + |builder, (index, value)| { if index == values.len() - 1 { - self.infer_expression(value) + builder.infer_expression(value) } else { - self.infer_standalone_expression(value) + builder.infer_standalone_expression(value) } - }), - values.len(), + }, ) } /// Computes the output of a chain of (one) boolean operation, consuming as input an iterator - /// of types. The iterator is consumed even if the boolean evaluation can be short-circuited, + /// of operations and calling the `infer_ty` for each to infer their types. + /// The iterator is consumed even if the boolean evaluation can be short-circuited, /// in order to ensure the invariant that all expressions are evaluated when inferring types. - fn infer_chained_boolean_types( - db: &'db dyn Db, + fn infer_chained_boolean_types( + &mut self, op: ast::BoolOp, - values: impl IntoIterator>, - n_values: usize, - ) -> Type<'db> { + operations: Iterator, + infer_ty: F, + ) -> Type<'db> + where + Iterator: IntoIterator, + F: Fn(&mut Self, Item) -> Type<'db>, + { let mut done = false; + let db = self.db(); - let elements = values.into_iter().enumerate().map(|(i, ty)| { - if done { - return Type::Never; - } + let elements = operations + .into_iter() + .with_position() + .map(|(position, ty)| { + let ty = infer_ty(self, ty); - let is_last = i == n_values - 1; - - match (ty.bool(db), is_last, op) { - (Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never, - (Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never, - - (Truthiness::AlwaysFalse, _, ast::BoolOp::And) - | (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => { - done = true; - ty + if done { + return Type::Never; } - (Truthiness::Ambiguous, false, _) => IntersectionBuilder::new(db) - .add_positive(ty) - .add_negative(match op { - ast::BoolOp::And => Type::AlwaysTruthy, - ast::BoolOp::Or => Type::AlwaysFalsy, - }) - .build(), + let is_last = matches!( + position, + itertools::Position::Last | itertools::Position::Only + ); - (_, true, _) => ty, - } - }); + match (ty.bool(db), is_last, op) { + (Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never, + (Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never, + + (Truthiness::AlwaysFalse, _, ast::BoolOp::And) + | (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => { + done = true; + ty + } + + (Truthiness::Ambiguous, false, _) => IntersectionBuilder::new(db) + .add_positive(ty) + .add_negative(match op { + ast::BoolOp::And => Type::AlwaysTruthy, + ast::BoolOp::Or => Type::AlwaysFalsy, + }) + .build(), + + (_, true, _) => ty, + } + }); UnionType::from_elements(db, elements) } @@ -4102,52 +4115,51 @@ impl<'db> TypeInferenceBuilder<'db> { // // As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below // is shared with the one in `infer_binary_type_comparison`. - Self::infer_chained_boolean_types( - self.db(), + self.infer_chained_boolean_types( ast::BoolOp::And, std::iter::once(&**left) .chain(comparators) .tuple_windows::<(_, _)>() - .zip(ops) - .map(|((left, right), op)| { - let left_ty = self.expression_type(left); - let right_ty = self.expression_type(right); + .zip(ops), + |builder, ((left, right), op)| { + let left_ty = builder.expression_type(left); + let right_ty = builder.expression_type(right); - self.infer_binary_type_comparison(left_ty, *op, right_ty) - .unwrap_or_else(|error| { - // Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome) - self.context.report_lint( - &UNSUPPORTED_OPERATOR, - AnyNodeRef::ExprCompare(compare), - format_args!( - "Operator `{}` is not supported for types `{}` and `{}`{}", - error.op, - error.left_ty.display(self.db()), - error.right_ty.display(self.db()), - if (left_ty, right_ty) == (error.left_ty, error.right_ty) { - String::new() - } else { - format!( - ", in comparing `{}` with `{}`", - left_ty.display(self.db()), - right_ty.display(self.db()) - ) - } - ), - ); + builder + .infer_binary_type_comparison(left_ty, *op, right_ty) + .unwrap_or_else(|error| { + // Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome) + builder.context.report_lint( + &UNSUPPORTED_OPERATOR, + AnyNodeRef::ExprCompare(compare), + format_args!( + "Operator `{}` is not supported for types `{}` and `{}`{}", + error.op, + error.left_ty.display(builder.db()), + error.right_ty.display(builder.db()), + if (left_ty, right_ty) == (error.left_ty, error.right_ty) { + String::new() + } else { + format!( + ", in comparing `{}` with `{}`", + left_ty.display(builder.db()), + right_ty.display(builder.db()) + ) + } + ), + ); - match op { - // `in, not in, is, is not` always return bool instances - ast::CmpOp::In - | ast::CmpOp::NotIn - | ast::CmpOp::Is - | ast::CmpOp::IsNot => KnownClass::Bool.to_instance(self.db()), - // Other operators can return arbitrary types - _ => Type::unknown(), - } - }) - }), - ops.len(), + match op { + // `in, not in, is, is not` always return bool instances + ast::CmpOp::In + | ast::CmpOp::NotIn + | ast::CmpOp::Is + | ast::CmpOp::IsNot => KnownClass::Bool.to_instance(builder.db()), + // Other operators can return arbitrary types + _ => Type::unknown(), + } + }) + }, ) }