diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index b1e762a95a..6809ae8467 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -13916,40 +13916,13 @@ impl<'db> UnionType<'db> { ConstraintSet::from(sorted_self == other.normalized(db)) } - /// Returns true if this union is equivalent to `int | float`, which is what `float` expands - /// into in type position. - pub(crate) fn is_int_float(self, db: &'db dyn Db) -> bool { - let elements = self.elements(db); - if elements.len() != 2 { - return false; - } - let mut has_int = false; - let mut has_float = false; - for element in elements { - if let Type::NominalInstance(nominal) = element - && let Some(known) = nominal.known_class(db) - { - match known { - KnownClass::Int => has_int = true, - KnownClass::Float => has_float = true, - _ => {} - } - } - } - has_int && has_float - } - /// Returns true if this union is equivalent to `int | float | complex`, which is what /// `complex` expands into in type position. - pub(crate) fn is_int_float_complex(self, db: &'db dyn Db) -> bool { - let elements = self.elements(db); - if elements.len() != 3 { - return false; - } + pub(crate) fn known(self, db: &'db dyn Db) -> Option { let mut has_int = false; let mut has_float = false; let mut has_complex = false; - for element in elements { + for element in self.elements(db) { if let Type::NominalInstance(nominal) = element && let Some(known) = nominal.known_class(db) { @@ -13957,14 +13930,26 @@ impl<'db> UnionType<'db> { KnownClass::Int => has_int = true, KnownClass::Float => has_float = true, KnownClass::Complex => has_complex = true, - _ => {} + _ => return None, } + } else { + return None; } } - has_int && has_float && has_complex + match (has_int, has_float, has_complex) { + (true, true, false) => Some(KnownUnion::Float), + (true, true, true) => Some(KnownUnion::Complex), + _ => None, + } } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum KnownUnion { + Float, // `int | float` + Complex, // `int | float | complex` +} + #[salsa::interned(debug, heap_size=IntersectionType::heap_size)] pub struct IntersectionType<'db> { /// The intersection type includes only values in all of these types. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 8cdf3277ca..75b1f6eb80 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -104,13 +104,13 @@ use crate::types::visitor::any_over_type; use crate::types::{ BoundTypeVarInstance, CallDunderError, CallableBinding, CallableType, CallableTypeKind, ClassLiteral, ClassType, DataclassParams, DynamicType, InternedType, IntersectionBuilder, - IntersectionType, KnownClass, KnownInstanceType, LintDiagnosticGuard, MemberLookupPolicy, - MetaclassCandidate, PEP695TypeAliasType, ParamSpecAttrKind, Parameter, ParameterForm, - Parameters, Signature, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, TypeVarBoundOrConstraints, - TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarIdentity, - TypeVarInstance, TypeVarKind, TypeVarVariance, TypedDictType, UnionBuilder, UnionType, - UnionTypeInstance, binding_type, infer_scope_types, todo_type, + IntersectionType, KnownClass, KnownInstanceType, KnownUnion, LintDiagnosticGuard, + MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, ParamSpecAttrKind, Parameter, + ParameterForm, Parameters, Signature, SpecialFormType, SubclassOfType, TrackedConstraintSet, + Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, + TypeVarBoundOrConstraints, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, + TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, TypedDictType, UnionBuilder, + UnionType, UnionTypeInstance, binding_type, infer_scope_types, todo_type, }; use crate::types::{CallableTypes, overrides}; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; @@ -5629,35 +5629,34 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Infer the deferred base type of a NewType. fn infer_newtype_assignment_deferred(&mut self, arguments: &ast::Arguments) { - match self.infer_type_expression(&arguments.args[1]) { - Type::NominalInstance(_) | Type::NewTypeInstance(_) => {} + let inferred = self.infer_type_expression(&arguments.args[1]); + match inferred { + Type::NominalInstance(_) | Type::NewTypeInstance(_) => return, // There are exactly two union types allowed as bases for NewType: `int | float` and // `int | float | complex`. These are allowed because that's what `float` and `complex` // expand into in type position. We don't currently ask whether the union was implicit // or explicit, so the explicit version is also allowed. - Type::Union(union_ty) - if union_ty.is_int_float(self.db()) || union_ty.is_int_float_complex(self.db()) => { - } + Type::Union(union_ty) => match union_ty.known(self.db()) { + Some(KnownUnion::Float) | Some(KnownUnion::Complex) => return, + _ => {} + }, // `Unknown` is likely to be the result of an unresolved import or a typo, which will // already get a diagnostic, so don't pile on an extra diagnostic here. - Type::Dynamic(DynamicType::Unknown) => {} - other_type => { - if let Some(builder) = self - .context - .report_lint(&INVALID_NEWTYPE, &arguments.args[1]) - { - let mut diag = builder.into_diagnostic("invalid base for `typing.NewType`"); - diag.set_primary_message(format!("type `{}`", other_type.display(self.db()))); - if matches!(other_type, Type::ProtocolInstance(_)) { - diag.info("The base of a `NewType` is not allowed to be a protocol class."); - } else if matches!(other_type, Type::TypedDict(_)) { - diag.info("The base of a `NewType` is not allowed to be a `TypedDict`."); - } else { - diag.info( - "The base of a `NewType` must be a class type or another `NewType`.", - ); - } - } + Type::Dynamic(DynamicType::Unknown) => return, + _ => {} + } + if let Some(builder) = self + .context + .report_lint(&INVALID_NEWTYPE, &arguments.args[1]) + { + let mut diag = builder.into_diagnostic("invalid base for `typing.NewType`"); + diag.set_primary_message(format!("type `{}`", inferred.display(self.db()))); + if matches!(inferred, Type::ProtocolInstance(_)) { + diag.info("The base of a `NewType` is not allowed to be a protocol class."); + } else if matches!(inferred, Type::TypedDict(_)) { + diag.info("The base of a `NewType` is not allowed to be a `TypedDict`."); + } else { + diag.info("The base of a `NewType` must be a class type or another `NewType`."); } } } diff --git a/crates/ty_python_semantic/src/types/newtype.rs b/crates/ty_python_semantic/src/types/newtype.rs index 220e7e7230..cc6f2cff69 100644 --- a/crates/ty_python_semantic/src/types/newtype.rs +++ b/crates/ty_python_semantic/src/types/newtype.rs @@ -3,7 +3,9 @@ use std::collections::BTreeSet; use crate::Db; use crate::semantic_index::definition::{Definition, DefinitionKind}; use crate::types::constraints::ConstraintSet; -use crate::types::{ClassType, KnownClass, Type, UnionType, definition_expression_type, visitor}; +use crate::types::{ + ClassType, KnownClass, KnownUnion, Type, UnionType, definition_expression_type, visitor, +}; use ruff_db::parsed::parsed_module; use ruff_python_ast as ast; @@ -84,8 +86,11 @@ impl<'db> NewType<'db> { // `int | float | complex`. These are allowed because that's what `float` and `complex` // expand into in type position. We don't currently ask whether the union was implicit // or explicit, so the explicit version is also allowed. - Type::Union(union_type) if union_type.is_int_float(db) => NewTypeBase::Float, - Type::Union(union_type) if union_type.is_int_float_complex(db) => NewTypeBase::Complex, + Type::Union(union_type) => match union_type.known(db) { + Some(KnownUnion::Float) => NewTypeBase::Float, + Some(KnownUnion::Complex) => NewTypeBase::Complex, + _ => object_fallback, + }, _ => object_fallback, } }