[ty] Support type context of union attribute assignments (#21170)

## Summary

Turns out this is easy to implement. Resolves
https://github.com/astral-sh/ty/issues/1375.
This commit is contained in:
Ibraheem Ahmed 2025-10-31 12:41:14 -04:00 committed by GitHub
parent 9664474c51
commit ff3a6a8fbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 7 deletions

View File

@ -200,7 +200,7 @@ def f() -> list[Literal[1]]:
return [1] return [1]
``` ```
## Instance attribute ## Instance attributes
```toml ```toml
[environment] [environment]
@ -235,6 +235,24 @@ def _(flag: bool):
C.x = lst(1) C.x = lst(1)
``` ```
For union targets, each element of the union is considered as a separate type context:
```py
from typing import Literal
class X:
x: list[int | str]
class Y:
x: list[int | None]
def lst[T](x: T) -> list[T]:
return [x]
def _(xy: X | Y):
xy.x = lst(1)
```
## Class constructor parameters ## Class constructor parameters
```toml ```toml

View File

@ -3574,7 +3574,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target: &ast::ExprAttribute, target: &ast::ExprAttribute,
object_ty: Type<'db>, object_ty: Type<'db>,
attribute: &str, attribute: &str,
infer_value_ty: &dyn Fn(&mut Self, TypeContext<'db>) -> Type<'db>, infer_value_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
emit_diagnostics: bool, emit_diagnostics: bool,
) -> bool { ) -> bool {
let db = self.db(); let db = self.db();
@ -3651,7 +3651,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match object_ty { match object_ty {
Type::Union(union) => { Type::Union(union) => {
// TODO: We could perform multi-inference here with each element of the union as type context. // First infer the value without type context, and then again for each union element.
let value_ty = infer_value_ty(self, TypeContext::default()); let value_ty = infer_value_ty(self, TypeContext::default());
if union.elements(self.db()).iter().all(|elem| { if union.elements(self.db()).iter().all(|elem| {
@ -3659,7 +3659,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target, target,
*elem, *elem,
attribute, attribute,
&|_, _| value_ty, // Note that `infer_value_ty` silences diagnostics after the first inference.
&mut infer_value_ty,
false, false,
) )
}) { }) {
@ -3684,7 +3685,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
Type::Intersection(intersection) => { Type::Intersection(intersection) => {
// TODO: We could perform multi-inference here with each element of the union as type context. // First infer the value without type context, and then again for each union element.
let value_ty = infer_value_ty(self, TypeContext::default()); let value_ty = infer_value_ty(self, TypeContext::default());
// TODO: Handle negative intersection elements // TODO: Handle negative intersection elements
@ -3693,7 +3694,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target, target,
*elem, *elem,
attribute, attribute,
&|_, _| value_ty, // Note that `infer_value_ty` silences diagnostics after the first inference.
&mut infer_value_ty,
false, false,
) )
}) { }) {
@ -4254,7 +4256,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let object_ty = self.infer_expression(object, TypeContext::default()); let object_ty = self.infer_expression(object, TypeContext::default());
if let Some(infer_assigned_ty) = infer_assigned_ty { if let Some(infer_assigned_ty) = infer_assigned_ty {
let infer_assigned_ty = &|builder: &mut Self, tcx| { let infer_assigned_ty = &mut |builder: &mut Self, tcx| {
let assigned_ty = infer_assigned_ty(builder, tcx); let assigned_ty = infer_assigned_ty(builder, tcx);
builder.store_expression_type(target, assigned_ty); builder.store_expression_type(target, assigned_ty);
assigned_ty assigned_ty