diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/literal.md b/crates/ty_python_semantic/resources/mdtest/annotations/literal.md index 05b0868523..0c266868e1 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/literal.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/literal.md @@ -25,6 +25,9 @@ class Color(Enum): b1: Literal[Color.RED] +MissingT = Enum("MissingT", {"MISSING": "MISSING"}) +b2: Literal[MissingT.MISSING] + def f(): reveal_type(mode) # revealed: Literal["w", "r"] reveal_type(a1) # revealed: Literal[26] @@ -51,6 +54,12 @@ invalid4: Literal[ hello, # error: [invalid-type-form] (1, 2, 3), # error: [invalid-type-form] ] + +class NotAnEnum: + x: int = 1 + +# error: [invalid-type-form] +invalid5: Literal[NotAnEnum.x] ``` ## Shortening unions of literals diff --git a/crates/ty_python_semantic/src/types/enums.rs b/crates/ty_python_semantic/src/types/enums.rs index 59c814c147..5c026beff5 100644 --- a/crates/ty_python_semantic/src/types/enums.rs +++ b/crates/ty_python_semantic/src/types/enums.rs @@ -240,3 +240,10 @@ pub(crate) fn enum_member_literals<'a, 'db: 'a>( pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool { enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1) } + +pub(crate) fn is_enum_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { + match ty { + Type::ClassLiteral(class_literal) => enum_metadata(db, class_literal).is_some(), + _ => false, + } +} diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index b0ab2b8492..caffb49aee 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -104,6 +104,7 @@ use crate::types::diagnostic::{ report_invalid_generator_function_return_type, report_invalid_return_type, report_possibly_unbound_attribute, }; +use crate::types::enums::is_enum_class; use crate::types::function::{ FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral, }; @@ -10414,14 +10415,23 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // For enum values ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { let value_ty = self.infer_expression(value); - // TODO: Check that value type is enum otherwise return None - let ty = value_ty - .member(self.db(), &attr.id) - .place - .ignore_possibly_unbound() - .unwrap_or(Type::unknown()); - self.store_expression_type(parameters, ty); - ty + + if is_enum_class(self.db(), value_ty) { + let ty = value_ty + .member(self.db(), &attr.id) + .place + .ignore_possibly_unbound() + .unwrap_or(Type::unknown()); + self.store_expression_type(parameters, ty); + ty + } else { + self.store_expression_type(parameters, Type::unknown()); + if value_ty.is_todo() { + value_ty + } else { + return Err(vec![parameters]); + } + } } // for negative and positive numbers ast::Expr::UnaryOp(u)