From fa7626160bd89ff76ac877ee9c71e15c5bad3ff8 Mon Sep 17 00:00:00 2001 From: David Peter Date: Mon, 21 Oct 2024 20:12:03 +0200 Subject: [PATCH] [red-knot] handle unions on the LHS of is_subtype_of (#13857) ## Summary Just a drive-by change that occurred to me while I was looking at `Type::is_subtype_of`: the existing pattern for unions on the *right hand side*: ```rs (ty, Type::Union(union)) => union .elements(db) .iter() .any(|&elem_ty| ty.is_subtype_of(db, elem_ty)), ``` is not (generally) correct if the *left hand side* is a union. ## Test Plan Added new test cases for `is_subtype_of` and `!is_subtype_of` --- crates/red_knot_python_semantic/src/types.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index dd66bf173b..22718df737 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -434,6 +434,10 @@ impl<'db> Type<'db> { { true } + (Type::Union(union), ty) => union + .elements(db) + .iter() + .all(|&elem_ty| elem_ty.is_subtype_of(db, ty)), (ty, Type::Union(union)) => union .elements(db) .iter() @@ -1839,6 +1843,8 @@ mod tests { #[test_case(Ty::LiteralString, Ty::BuiltinInstance("str"))] #[test_case(Ty::BytesLiteral("foo"), Ty::BuiltinInstance("bytes"))] #[test_case(Ty::IntLiteral(1), Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")]))] + #[test_case(Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]), Ty::BuiltinInstance("object"))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2), Ty::IntLiteral(3)]))] #[test_case(Ty::BuiltinInstance("TypeError"), Ty::BuiltinInstance("Exception"))] fn is_subtype_of(from: Ty, to: Ty) { let db = setup_db(); @@ -1852,6 +1858,8 @@ mod tests { #[test_case(Ty::IntLiteral(1), Ty::Any)] #[test_case(Ty::IntLiteral(1), Ty::Union(vec![Ty::Unknown, Ty::BuiltinInstance("str")]))] #[test_case(Ty::IntLiteral(1), Ty::BuiltinInstance("str"))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::IntLiteral(1))] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(3)]))] #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))] #[test_case(Ty::BuiltinInstance("int"), Ty::IntLiteral(1))] fn is_not_subtype_of(from: Ty, to: Ty) {