[ty] Eagerly evaluate `types.UnionType` elements as type expressions (#21531)

## Summary

Eagerly evaluate the elements of a PEP 604 union in value position (e.g.
`IntOrStr = int | str`) as type expressions, and store the result on the
`UnionTypeInstance`. The stored result is the corresponding `Type::Union`
if all elements are valid type expressions, or otherwise the first
encountered `InvalidTypeExpressionError`. This way, the `Type::Union(…)`
does not need to be recomputed every time the implicit type alias is used
in a type annotation.

This might lead to performance improvements for large unions, but is
also necessary for correctness, because the elements of the union might
refer to type variables that need to be looked up in the scope of the
type alias, not at the usage site.

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-11-20 17:28:48 +01:00 committed by GitHub
parent 416e2267da
commit 0761ea42d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 200 additions and 100 deletions

View File

@ -191,13 +191,13 @@ def _(
reveal_type(int_or_callable) # revealed: int | ((str, /) -> bytes)
reveal_type(callable_or_int) # revealed: ((str, /) -> bytes) | int
# TODO should be Unknown | int
reveal_type(type_var_or_int) # revealed: T@_ | int
reveal_type(type_var_or_int) # revealed: typing.TypeVar | int
# TODO should be int | Unknown
reveal_type(int_or_type_var) # revealed: int | T@_
reveal_type(int_or_type_var) # revealed: int | typing.TypeVar
# TODO should be Unknown | None
reveal_type(type_var_or_none) # revealed: T@_ | None
reveal_type(type_var_or_none) # revealed: typing.TypeVar | None
# TODO should be None | Unknown
reveal_type(none_or_type_var) # revealed: None | T@_
reveal_type(none_or_type_var) # revealed: None | typing.TypeVar
```
If a type is unioned with itself in a value expression, the result is just that type. No

View File

@ -159,19 +159,43 @@ IntOrStr = Union[int, str]
reveal_type(IntOrStr) # revealed: types.UnionType
def _(x: int | str | bytes | memoryview | range):
# TODO: no error
# error: [invalid-argument-type]
if isinstance(x, IntOrStr):
# TODO: Should be `int | str`
reveal_type(x) # revealed: int | str | bytes | memoryview[int] | range
# TODO: no error
# error: [invalid-argument-type]
reveal_type(x) # revealed: int | str
elif isinstance(x, Union[bytes, memoryview]):
# TODO: Should be `bytes | memoryview[int]`
reveal_type(x) # revealed: int | str | bytes | memoryview[int] | range
reveal_type(x) # revealed: bytes | memoryview[int]
else:
# TODO: Should be `range`
reveal_type(x) # revealed: int | str | bytes | memoryview[int] | range
reveal_type(x) # revealed: range
def _(x: int | str | None):
if isinstance(x, Union[int, None]):
reveal_type(x) # revealed: int | None
else:
reveal_type(x) # revealed: str
ListStrOrInt = Union[list[str], int]
def _(x: dict[int, str] | ListStrOrInt):
# TODO: this should ideally be an error
if isinstance(x, ListStrOrInt):
# TODO: this should not be narrowed
reveal_type(x) # revealed: list[str] | int
# TODO: this should ideally be an error
if isinstance(x, Union[list[str], int]):
# TODO: this should not be narrowed
reveal_type(x) # revealed: list[str] | int
```
## `Optional` as `classinfo`
```py
from typing import Optional
def _(x: int | str | None):
if isinstance(x, Optional[int]):
reveal_type(x) # revealed: int | None
else:
reveal_type(x) # revealed: str
```
## `classinfo` is a `typing.py` special form
@ -289,6 +313,23 @@ def _(flag: bool):
reveal_type(x) # revealed: Literal[1, "a"]
```
## Generic aliases are not supported as second argument
The `classinfo` argument cannot be a generic alias:
```py
def _(x: list[str] | list[int] | list[bytes]):
# TODO: Ideally, this would be an error (requires https://github.com/astral-sh/ty/issues/116)
if isinstance(x, list[int]):
# No narrowing here:
reveal_type(x) # revealed: list[str] | list[int] | list[bytes]
# error: [invalid-argument-type] "Invalid second argument to `isinstance`"
if isinstance(x, list[int] | list[str]):
# No narrowing here:
reveal_type(x) # revealed: list[str] | list[int] | list[bytes]
```
## `type[]` types are narrowed as well as class-literal types
```py

View File

@ -212,19 +212,12 @@ IntOrStr = Union[int, str]
reveal_type(IntOrStr) # revealed: types.UnionType
def f(x: type[int | str | bytes | range]):
# TODO: No error
# error: [invalid-argument-type]
if issubclass(x, IntOrStr):
# TODO: Should be `type[int] | type[str]`
reveal_type(x) # revealed: type[int] | type[str] | type[bytes] | <class 'range'>
# TODO: No error
# error: [invalid-argument-type]
reveal_type(x) # revealed: type[int] | type[str]
elif issubclass(x, Union[bytes, memoryview]):
# TODO: Should be `type[bytes]`
reveal_type(x) # revealed: type[int] | type[str] | type[bytes] | <class 'range'>
reveal_type(x) # revealed: type[bytes]
else:
# TODO: Should be `<class 'range'>`
reveal_type(x) # revealed: type[int] | type[str] | type[bytes] | <class 'range'>
reveal_type(x) # revealed: <class 'range'>
```
## Special cases

View File

@ -6738,17 +6738,10 @@ impl<'db> Type<'db> {
invalid_expressions: smallvec::smallvec_inline![InvalidTypeExpression::Generic],
fallback_type: Type::unknown(),
}),
KnownInstanceType::UnionType(list) => {
let mut builder = UnionBuilder::new(db);
let inferred_as = list.inferred_as(db);
for element in list.elements(db) {
builder = builder.add(if inferred_as.type_expression() {
*element
} else {
element.in_type_expression(db, scope_id, typevar_binding_context)?
});
}
Ok(builder.build())
KnownInstanceType::UnionType(instance) => {
// Cloning here is cheap if the result is a `Type` (which is `Copy`). It's more
// expensive if there are errors.
instance.union_type(db).clone()
}
KnownInstanceType::Literal(ty) => Ok(ty.inner(db)),
KnownInstanceType::Annotated(ty) => Ok(ty.inner(db)),
@ -8004,9 +7997,9 @@ pub enum KnownInstanceType<'db> {
/// `ty_extensions.Specialization`.
Specialization(Specialization<'db>),
/// A single instance of `types.UnionType`, which stores the left- and
/// right-hand sides of a PEP 604 union.
UnionType(InternedTypes<'db>),
/// A single instance of `types.UnionType`, which stores the elements of
/// a PEP 604 union, or a `typing.Union`.
UnionType(UnionTypeInstance<'db>),
/// A single instance of `typing.Literal`
Literal(InternedType<'db>),
@ -8052,9 +8045,9 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
visitor.visit_type(db, default_ty);
}
}
KnownInstanceType::UnionType(list) => {
for element in list.elements(db) {
visitor.visit_type(db, *element);
KnownInstanceType::UnionType(instance) => {
if let Ok(union_type) = instance.union_type(db) {
visitor.visit_type(db, *union_type);
}
}
KnownInstanceType::Literal(ty)
@ -8098,7 +8091,7 @@ impl<'db> KnownInstanceType<'db> {
Self::TypeAliasType(type_alias.normalized_impl(db, visitor))
}
Self::Field(field) => Self::Field(field.normalized_impl(db, visitor)),
Self::UnionType(list) => Self::UnionType(list.normalized_impl(db, visitor)),
Self::UnionType(instance) => Self::UnionType(instance.normalized_impl(db, visitor)),
Self::Literal(ty) => Self::Literal(ty.normalized_impl(db, visitor)),
Self::Annotated(ty) => Self::Annotated(ty.normalized_impl(db, visitor)),
Self::TypeGenericAlias(ty) => Self::TypeGenericAlias(ty.normalized_impl(db, visitor)),
@ -8430,7 +8423,7 @@ impl<'db> TypeAndQualifiers<'db> {
/// Error struct providing information on type(s) that were deemed to be invalid
/// in a type expression context, and the type we should therefore fallback to
/// for the problematic type expression.
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
pub struct InvalidTypeExpressionError<'db> {
fallback_type: Type<'db>,
invalid_expressions: smallvec::SmallVec<[InvalidTypeExpression<'db>; 1]>,
@ -8461,7 +8454,7 @@ impl<'db> InvalidTypeExpressionError<'db> {
}
/// Enumeration of various types that are invalid in type-expression contexts
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize)]
enum InvalidTypeExpression<'db> {
/// Some types always require exactly one argument when used in a type expression
RequiresOneArgument(Type<'db>),
@ -9399,39 +9392,106 @@ impl InferredAs {
}
}
/// A salsa-interned list of types.
/// Contains information about a `types.UnionType` instance built from a PEP 604
/// union or a legacy `typing.Union[…]` annotation in a value expression context,
/// e.g. `IntOrStr = int | str` or `IntOrStr = Union[int, str]`.
///
/// # Ordering
/// Ordering is based on the instance's salsa-assigned id and not on its values.
/// The id may change between runs, or when the instance was garbage collected and recreated.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(PartialOrd, Ord)]
pub struct InternedTypes<'db> {
#[returns(deref)]
elements: Box<[Type<'db>]>,
inferred_as: InferredAs,
pub struct UnionTypeInstance<'db> {
/// The types of the elements of this union, as they were inferred in a value
/// expression context. For `int | str`, this would contain `<class 'int'>` and
/// `<class 'str'>`. For `Union[int, str]`, this field is `None`, as we infer
/// the elements as type expressions. Use `value_expression_types` to get the
/// corresponding value expression types.
#[expect(clippy::ref_option)]
#[returns(ref)]
_value_expr_types: Option<Box<[Type<'db>]>>,
/// The type of the full union, which can be used when this `UnionType` instance
/// is used in a type expression context. For `int | str`, this would contain
/// `Ok(int | str)`. If any of the element types could not be converted, this
/// contains the first encountered error.
#[returns(ref)]
union_type: Result<Type<'db>, InvalidTypeExpressionError<'db>>,
}
impl get_size2::GetSize for InternedTypes<'_> {}
impl get_size2::GetSize for UnionTypeInstance<'_> {}
impl<'db> InternedTypes<'db> {
pub(crate) fn from_elements(
impl<'db> UnionTypeInstance<'db> {
pub(crate) fn from_value_expression_types(
db: &'db dyn Db,
elements: impl IntoIterator<Item = Type<'db>>,
inferred_as: InferredAs,
) -> InternedTypes<'db> {
InternedTypes::new(db, elements.into_iter().collect::<Box<[_]>>(), inferred_as)
value_expr_types: impl IntoIterator<Item = Type<'db>>,
scope_id: ScopeId<'db>,
typevar_binding_context: Option<Definition<'db>>,
) -> Type<'db> {
let value_expr_types = value_expr_types.into_iter().collect::<Box<_>>();
let mut builder = UnionBuilder::new(db);
for ty in &value_expr_types {
match ty.in_type_expression(db, scope_id, typevar_binding_context) {
Ok(ty) => builder.add_in_place(ty),
Err(error) => {
return Type::KnownInstance(KnownInstanceType::UnionType(
UnionTypeInstance::new(db, Some(value_expr_types), Err(error)),
));
}
}
}
Type::KnownInstance(KnownInstanceType::UnionType(UnionTypeInstance::new(
db,
Some(value_expr_types),
Ok(builder.build()),
)))
}
/// Get the types of the elements of this union as they would appear in a value
/// expression context. For a PEP 604 union, we return the actual types that were
/// inferred when we encountered the union in a value expression context. For a
/// legacy `typing.Union[…]` annotation, we turn the type-expression types into
/// their corresponding value-expression types, i.e. we turn instances like `int`
/// into class literals like `<class 'int'>`. This operation is potentially lossy.
///
/// # Errors
///
/// Returns the stored `InvalidTypeExpressionError` if no value-expression types
/// were recorded for this instance and its stored union type is an error.
pub(crate) fn value_expression_types(
self,
db: &'db dyn Db,
) -> Result<impl Iterator<Item = Type<'db>> + 'db, InvalidTypeExpressionError<'db>> {
// Map a nominal instance type back to its class literal. Any type that is not
// a nominal instance becomes `Unknown`, which is where information can be lost.
let to_class_literal = |ty: Type<'db>| {
ty.as_nominal_instance()
.map(|instance| Type::ClassLiteral(instance.class(db).class_literal(db).0))
.unwrap_or_else(Type::unknown)
};
if let Some(value_expr_types) = self._value_expr_types(db) {
// Value-expression types were recorded on this instance: return them unchanged.
Ok(Either::Left(value_expr_types.iter().copied()))
} else {
// No recorded value-expression types: derive them from the stored union type.
// The stored result is cloned so that `?` can propagate an owned error.
match self.union_type(db).clone()? {
// A union type: convert each of its elements individually.
Type::Union(union) => Ok(Either::Right(Either::Left(
union.elements(db).iter().copied().map(to_class_literal),
))),
// Any other type is converted as a single element (presumably the case
// where the union was simplified to one type — TODO confirm).
ty => Ok(Either::Right(Either::Right(std::iter::once(
to_class_literal(ty),
)))),
}
}
}
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
InternedTypes::new(
db,
self.elements(db)
let value_expr_types = self._value_expr_types(db).as_ref().map(|types| {
types
.iter()
.map(|ty| ty.normalized_impl(db, visitor))
.collect::<Box<[_]>>(),
self.inferred_as(db),
)
.collect::<Box<_>>()
});
let union_type = self
.union_type(db)
.clone()
.map(|ty| ty.normalized_impl(db, visitor));
Self::new(db, value_expr_types, union_type)
}
}

View File

@ -1790,16 +1790,23 @@ impl KnownFunction {
// `Any` can be used in `issubclass()` calls but not `isinstance()` calls
Type::SpecialForm(SpecialFormType::Any)
if function == KnownFunction::IsSubclass => {}
Type::KnownInstance(KnownInstanceType::UnionType(union)) => {
for element in union.elements(db) {
Type::KnownInstance(KnownInstanceType::UnionType(instance)) => {
match instance.value_expression_types(db) {
Ok(value_expression_types) => {
for element in value_expression_types {
find_invalid_elements(
db,
function,
*element,
element,
invalid_elements,
);
}
}
Err(_) => {
invalid_elements.push(ty);
}
}
}
_ => invalid_elements.push(ty),
}
}

View File

@ -102,13 +102,13 @@ use crate::types::typed_dict::{
use crate::types::visitor::any_over_type;
use crate::types::{
CallDunderError, CallableBinding, CallableType, ClassLiteral, ClassType, DataclassParams,
DynamicType, InferredAs, InternedType, InternedTypes, IntersectionBuilder, IntersectionType,
KnownClass, KnownInstanceType, LintDiagnosticGuard, MemberLookupPolicy, MetaclassCandidate,
DynamicType, InternedType, IntersectionBuilder, IntersectionType, KnownClass,
KnownInstanceType, LintDiagnosticGuard, MemberLookupPolicy, MetaclassCandidate,
PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, SubclassOfType,
TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeContext,
TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarIdentity,
TypeVarInstance, TypeVarKind, TypeVarVariance, TypedDictType, UnionBuilder, UnionType,
binding_type, todo_type,
UnionTypeInstance, binding_type, todo_type,
};
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
use crate::unpack::{EvaluationMode, UnpackPosition};
@ -9545,13 +9545,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if left_ty.is_equivalent_to(self.db(), right_ty) {
Some(left_ty)
} else {
Some(Type::KnownInstance(KnownInstanceType::UnionType(
InternedTypes::from_elements(
Some(UnionTypeInstance::from_value_expression_types(
self.db(),
[left_ty, right_ty],
InferredAs::ValueExpression,
),
)))
self.scope(),
self.typevar_binding_context,
))
}
}
(
@ -9574,13 +9573,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) if pep_604_unions_allowed()
&& instance.has_known_class(self.db(), KnownClass::NoneType) =>
{
Some(Type::KnownInstance(KnownInstanceType::UnionType(
InternedTypes::from_elements(
Some(UnionTypeInstance::from_value_expression_types(
self.db(),
[left_ty, right_ty],
InferredAs::ValueExpression,
),
)))
self.scope(),
self.typevar_binding_context,
))
}
// We avoid calling `type.__(r)or__`, as typeshed annotates these methods as
@ -10801,13 +10799,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return ty;
}
return Type::KnownInstance(KnownInstanceType::UnionType(
InternedTypes::from_elements(
return UnionTypeInstance::from_value_expression_types(
self.db(),
[ty, Type::none(self.db())],
InferredAs::ValueExpression,
),
));
self.scope(),
self.typevar_binding_context,
);
}
Type::SpecialForm(SpecialFormType::Union) => {
let db = self.db();
@ -10822,7 +10819,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let is_empty = elements.peek().is_none();
let union_type = Type::KnownInstance(KnownInstanceType::UnionType(
InternedTypes::from_elements(db, elements, InferredAs::TypeExpression),
UnionTypeInstance::new(
db,
None,
Ok(UnionType::from_elements(db, elements)),
),
));
if is_empty {

View File

@ -212,10 +212,10 @@ impl ClassInfoConstraintFunction {
)
}),
Type::KnownInstance(KnownInstanceType::UnionType(elements)) => {
Type::KnownInstance(KnownInstanceType::UnionType(instance)) => {
UnionType::try_from_elements(
db,
elements.elements(db).iter().map(|element| {
instance.value_expression_types(db).ok()?.map(|element| {
// A special case is made for `None` at runtime
// (it's implicitly converted to `NoneType` in `int | None`)
// which means that `isinstance(x, int | None)` works even though
@ -223,7 +223,7 @@ impl ClassInfoConstraintFunction {
if element.is_none(db) {
self.generate_constraint(db, KnownClass::NoneType.to_class_literal(db))
} else {
self.generate_constraint(db, *element)
self.generate_constraint(db, element)
}
}),
)
@ -874,8 +874,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let callable_ty = inference.expression_type(&*expr_call.func);
// TODO: add support for PEP 604 union types on the right hand side of `isinstance`
// and `issubclass`, for example `isinstance(x, str | (int | float))`.
match callable_ty {
Type::FunctionLiteral(function_type)
if matches!(