diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 0631232809..69ae1e7cc4 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1729,72 +1729,52 @@ impl<'db> TypeInferenceBuilder<'db> { let left_ty = self.infer_expression(left); let right_ty = self.infer_expression(right); - // TODO flatten the matches by matching on (left_ty, right_ty, op) - match left_ty { - Type::Any => Type::Any, - Type::Unknown => Type::Unknown, - Type::IntLiteral(n) => { - match right_ty { - Type::IntLiteral(m) => { - match op { - ast::Operator::Add => { - n.checked_add(m).map(Type::IntLiteral).unwrap_or_else(|| { - builtins_symbol_ty_by_name(self.db, "int").instance() - }) - } - ast::Operator::Sub => { - n.checked_sub(m).map(Type::IntLiteral).unwrap_or_else(|| { - builtins_symbol_ty_by_name(self.db, "int").instance() - }) - } - ast::Operator::Mult => { - n.checked_mul(m).map(Type::IntLiteral).unwrap_or_else(|| { - builtins_symbol_ty_by_name(self.db, "int").instance() - }) - } - ast::Operator::Div => { - n.checked_div(m).map(Type::IntLiteral).unwrap_or_else(|| { - builtins_symbol_ty_by_name(self.db, "int").instance() - }) - } - ast::Operator::Mod => n - .checked_rem(m) - .map(Type::IntLiteral) - // TODO division by zero error - .unwrap_or(Type::Unknown), - _ => Type::Unknown, // TODO - } - } - _ => Type::Unknown, // TODO - } + match (left_ty, right_ty, op) { + (Type::Any, _, _) | (_, Type::Any, _) => Type::Any, + (Type::Unknown, _, _) | (_, Type::Unknown, _) => Type::Unknown, + + (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n + .checked_add(m) + .map(Type::IntLiteral) + .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), + + (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n + .checked_sub(m) + .map(Type::IntLiteral) + .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), + + (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n + .checked_mul(m) + .map(Type::IntLiteral) + .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), + + (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Div) => n + .checked_div(m) + .map(Type::IntLiteral) + .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), + + (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n + .checked_rem(m) + .map(Type::IntLiteral) + // TODO division by zero error + .unwrap_or(Type::Unknown), + + (Type::BytesLiteral(lhs), Type::BytesLiteral(rhs), ast::Operator::Add) => { + Type::BytesLiteral(BytesLiteralType::new( + self.db, + [lhs.value(self.db).as_ref(), rhs.value(self.db).as_ref()] + .concat() + .into_boxed_slice(), + )) } - Type::BytesLiteral(lhs) => { - match right_ty { - Type::BytesLiteral(rhs) => { - match op { - ast::Operator::Add => Type::BytesLiteral(BytesLiteralType::new( - self.db, - [lhs.value(self.db).as_ref(), rhs.value(self.db).as_ref()] - .concat() - .into_boxed_slice(), - )), - _ => Type::Unknown, // TODO - } - } - _ => Type::Unknown, // TODO - } + + (Type::StringLiteral(lhs), Type::StringLiteral(rhs), ast::Operator::Add) => { + Type::StringLiteral(StringLiteralType::new(self.db, { + let lhs_value = lhs.value(self.db); + let rhs_value = rhs.value(self.db); + lhs_value.clone() + rhs_value + })) } - Type::StringLiteral(lhs) => match right_ty { - Type::StringLiteral(rhs) => match op { - ast::Operator::Add => Type::StringLiteral(StringLiteralType::new(self.db, { - let lhs_value = lhs.value(self.db); - let rhs_value = rhs.value(self.db); - lhs_value.clone() + rhs_value - })), - _ => Type::Unknown, // TODO - }, - _ => Type::Unknown, // TODO - }, _ => Type::Unknown, // TODO } }