[ty] Temporary SQLAlchemy special-case

This commit is contained in:
David Peter 2025-12-09 10:34:09 +01:00
parent 4e67a219bb
commit 352628e986
4 changed files with 145 additions and 34 deletions

View File

@ -106,45 +106,36 @@ reveal_type(admin_users) # revealed: Sequence[User]
We can also specify particular columns to select: We can also specify particular columns to select:
```py ```py
reveal_type(User.id) # revealed: InstrumentedAttribute[int]
stmt = select(User.id, User.name) stmt = select(User.id, User.name)
# TODO: should be `Select[tuple[int, str]]` reveal_type(stmt) # revealed: Select[tuple[int, str]]
reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]]
ids_and_names = session.execute(stmt).all() ids_and_names = session.execute(stmt).all()
# TODO: should be `Sequence[Row[tuple[int, str]]]` reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]]
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]]
for row in session.execute(stmt): for row in session.execute(stmt):
# TODO: should be `Row[tuple[int, str]]` reveal_type(row) # revealed: Row[tuple[int, str]]
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
for user_id, name in session.execute(stmt).tuples(): for user_id, name in session.execute(stmt).tuples():
# TODO: should be `int` reveal_type(user_id) # revealed: int
reveal_type(user_id) # revealed: Unknown reveal_type(name) # revealed: str
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
result = session.execute(stmt) result = session.execute(stmt)
row = result.one_or_none() row = result.one_or_none()
assert row is not None assert row is not None
(user_id, name) = row._tuple() (user_id, name) = row._tuple()
# TODO: should be `int` reveal_type(user_id) # revealed: int
reveal_type(user_id) # revealed: Unknown reveal_type(name) # revealed: str
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
stmt = select(User.id).where(User.name == "Alice") stmt = select(User.id).where(User.name == "Alice")
# TODO: should be `Select[tuple[int]]` reveal_type(stmt) # revealed: Select[tuple[int]]
reveal_type(stmt) # revealed: Select[tuple[Unknown]]
alice_id = session.scalars(stmt).first() alice_id = session.scalars(stmt).first()
# TODO: should be `int | None` reveal_type(alice_id) # revealed: int | None
reveal_type(alice_id) # revealed: Unknown | None
alice_id = session.scalar(stmt) alice_id = session.scalar(stmt)
# TODO: should be `int | None` reveal_type(alice_id) # revealed: int | None
reveal_type(alice_id) # revealed: Unknown | None
``` ```
Using the legacy `query` API also works: Using the legacy `query` API also works:
@ -203,8 +194,6 @@ async def test_async(session: AsyncSession):
stmt = select(User.id, User.name) stmt = select(User.id, User.name)
result = await session.execute(stmt) result = await session.execute(stmt)
for user_id, name in result.tuples(): for user_id, name in result.tuples():
# TODO: should be `int` reveal_type(user_id) # revealed: int
reveal_type(user_id) # revealed: Unknown reveal_type(name) # revealed: str
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
``` ```

View File

@ -335,6 +335,12 @@ pub enum KnownModule {
#[cfg(test)] #[cfg(test)]
Uuid, Uuid,
Warnings, Warnings,
#[strum(serialize = "sqlalchemy.sql.selectable")]
SqlalchemySqlSelectable,
#[strum(serialize = "sqlalchemy.sql._selectable_constructors")]
SqlalchemySqlSelectableConstructors,
#[strum(serialize = "sqlalchemy.orm.attributes")]
SqlalchemyOrmAttributes,
} }
impl KnownModule { impl KnownModule {
@ -363,6 +369,9 @@ impl KnownModule {
#[cfg(test)] #[cfg(test)]
Self::Uuid => "uuid", Self::Uuid => "uuid",
Self::Templatelib => "string.templatelib", Self::Templatelib => "string.templatelib",
Self::SqlalchemySqlSelectable => "sqlalchemy.sql.selectable",
Self::SqlalchemySqlSelectableConstructors => "sqlalchemy.sql._selectable_constructors",
Self::SqlalchemyOrmAttributes => "sqlalchemy.orm.attributes",
} }
} }
@ -378,7 +387,20 @@ impl KnownModule {
if search_path.is_standard_library() { if search_path.is_standard_library() {
Self::from_str(name.as_str()).ok() Self::from_str(name.as_str()).ok()
} else { } else {
None // For non-stdlib search paths, check for known third-party modules
Self::try_from_third_party_name(name)
}
}
/// Returns a known module for third-party packages, if applicable.
fn try_from_third_party_name(name: &ModuleName) -> Option<Self> {
match name.as_str() {
"sqlalchemy.sql.selectable" => Some(Self::SqlalchemySqlSelectable),
"sqlalchemy.sql._selectable_constructors" => {
Some(Self::SqlalchemySqlSelectableConstructors)
}
"sqlalchemy.orm.attributes" => Some(Self::SqlalchemyOrmAttributes),
_ => None,
} }
} }
@ -419,6 +441,11 @@ mod tests {
let stdlib_search_path = SearchPath::vendored_stdlib(); let stdlib_search_path = SearchPath::vendored_stdlib();
for module in KnownModule::iter() { for module in KnownModule::iter() {
// Third-party modules aren't available in the vendored stdlib
if module.is_third_party() {
continue;
}
let module_name = module.name(); let module_name = module.name();
assert_eq!( assert_eq!(

View File

@ -4207,6 +4207,9 @@ pub enum KnownClass {
ConstraintSet, ConstraintSet,
GenericContext, GenericContext,
Specialization, Specialization,
// sqlalchemy
SqlalchemySelect,
SqlalchemyInstrumentedAttribute,
} }
impl KnownClass { impl KnownClass {
@ -4315,7 +4318,9 @@ impl KnownClass {
| Self::GenericContext | Self::GenericContext
| Self::Specialization | Self::Specialization
| Self::ProtocolMeta | Self::ProtocolMeta
| Self::TypedDictFallback => Some(Truthiness::Ambiguous), | Self::TypedDictFallback
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => Some(Truthiness::Ambiguous),
Self::Tuple => None, Self::Tuple => None,
} }
@ -4405,7 +4410,9 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType | KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta | KnownClass::ProtocolMeta
| KnownClass::Template | KnownClass::Template
| KnownClass::Path => false, | KnownClass::Path
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
} }
} }
@ -4492,7 +4499,9 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType | KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta | KnownClass::ProtocolMeta
| KnownClass::Template | KnownClass::Template
| KnownClass::Path => false, | KnownClass::Path
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
} }
} }
@ -4578,7 +4587,9 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType | KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta | KnownClass::ProtocolMeta
| KnownClass::Template | KnownClass::Template
| KnownClass::Path => false, | KnownClass::Path
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
} }
} }
@ -4677,7 +4688,9 @@ impl KnownClass {
| Self::ProtocolMeta | Self::ProtocolMeta
| Self::Template | Self::Template
| Self::Path | Self::Path
| Self::Mapping => false, | Self::Mapping
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => false,
} }
} }
@ -4766,7 +4779,9 @@ impl KnownClass {
| KnownClass::ConstraintSet | KnownClass::ConstraintSet
| KnownClass::GenericContext | KnownClass::GenericContext
| KnownClass::Specialization | KnownClass::Specialization
| KnownClass::InitVar => false, | KnownClass::InitVar
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
KnownClass::NamedTupleFallback | KnownClass::TypedDictFallback => true, KnownClass::NamedTupleFallback | KnownClass::TypedDictFallback => true,
} }
} }
@ -4882,6 +4897,8 @@ impl KnownClass {
Self::Template => "Template", Self::Template => "Template",
Self::Path => "Path", Self::Path => "Path",
Self::ProtocolMeta => "_ProtocolMeta", Self::ProtocolMeta => "_ProtocolMeta",
Self::SqlalchemySelect => "Select",
Self::SqlalchemyInstrumentedAttribute => "InstrumentedAttribute",
} }
} }
@ -5203,6 +5220,8 @@ impl KnownClass {
| Self::Specialization => KnownModule::TyExtensions, | Self::Specialization => KnownModule::TyExtensions,
Self::Template => KnownModule::Templatelib, Self::Template => KnownModule::Templatelib,
Self::Path => KnownModule::Pathlib, Self::Path => KnownModule::Pathlib,
Self::SqlalchemySelect => KnownModule::SqlalchemySqlSelectable,
Self::SqlalchemyInstrumentedAttribute => KnownModule::SqlalchemyOrmAttributes,
} }
} }
@ -5291,7 +5310,9 @@ impl KnownClass {
| Self::BuiltinFunctionType | Self::BuiltinFunctionType
| Self::ProtocolMeta | Self::ProtocolMeta
| Self::Template | Self::Template
| Self::Path => Some(false), | Self::Path
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => Some(false),
Self::Tuple => None, Self::Tuple => None,
} }
@ -5383,7 +5404,9 @@ impl KnownClass {
| Self::BuiltinFunctionType | Self::BuiltinFunctionType
| Self::ProtocolMeta | Self::ProtocolMeta
| Self::Template | Self::Template
| Self::Path => false, | Self::Path
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => false,
} }
} }
@ -5489,6 +5512,8 @@ impl KnownClass {
"Template" => &[Self::Template], "Template" => &[Self::Template],
"Path" => &[Self::Path], "Path" => &[Self::Path],
"_ProtocolMeta" => &[Self::ProtocolMeta], "_ProtocolMeta" => &[Self::ProtocolMeta],
"Select" => &[Self::SqlalchemySelect],
"InstrumentedAttribute" => &[Self::SqlalchemyInstrumentedAttribute],
_ => return None, _ => return None,
}; };
@ -5569,7 +5594,9 @@ impl KnownClass {
| Self::Awaitable | Self::Awaitable
| Self::Generator | Self::Generator
| Self::Template | Self::Template
| Self::Path => module == self.canonical_module(db), | Self::Path
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => module == self.canonical_module(db),
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types), Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
Self::SpecialForm Self::SpecialForm
| Self::TypeAliasType | Self::TypeAliasType
@ -5924,6 +5951,10 @@ mod tests {
source: PythonVersionSource::default(), source: PythonVersionSource::default(),
}); });
for class in KnownClass::iter() { for class in KnownClass::iter() {
if class.canonical_module(&db).is_third_party() {
continue;
}
let class_name = class.name(&db); let class_name = class.name(&db);
let class_module = let class_module =
resolve_module_confident(&db, &class.canonical_module(&db).name()).unwrap(); resolve_module_confident(&db, &class.canonical_module(&db).name()).unwrap();
@ -5952,6 +5983,10 @@ mod tests {
}); });
for class in KnownClass::iter() { for class in KnownClass::iter() {
if class.canonical_module(&db).is_third_party() {
continue;
}
// Check the class can be looked up successfully // Check the class can be looked up successfully
class.try_to_class_literal_without_logging(&db).unwrap(); class.try_to_class_literal_without_logging(&db).unwrap();
@ -5977,6 +6012,7 @@ mod tests {
// This makes the test far faster as it minimizes the number of times // This makes the test far faster as it minimizes the number of times
// we need to change the Python version in the loop. // we need to change the Python version in the loop.
let mut classes: Vec<(KnownClass, PythonVersion)> = KnownClass::iter() let mut classes: Vec<(KnownClass, PythonVersion)> = KnownClass::iter()
.filter(|class| !class.canonical_module(&db).is_third_party())
.map(|class| { .map(|class| {
let version_added = match class { let version_added = match class {
KnownClass::Template => PythonVersion::PY314, KnownClass::Template => PythonVersion::PY314,

View File

@ -1353,6 +1353,10 @@ pub enum KnownFunction {
RevealProtocolInterface, RevealProtocolInterface,
/// `ty_extensions.reveal_mro` /// `ty_extensions.reveal_mro`
RevealMro, RevealMro,
/// `sqlalchemy.select`
#[strum(serialize = "select")]
SqlalchemySelect,
} }
impl KnownFunction { impl KnownFunction {
@ -1425,6 +1429,9 @@ impl KnownFunction {
Self::TypeCheckOnly => matches!(module, KnownModule::Typing), Self::TypeCheckOnly => matches!(module, KnownModule::Typing),
Self::NamedTuple => matches!(module, KnownModule::Collections), Self::NamedTuple => matches!(module, KnownModule::Collections),
Self::SqlalchemySelect => {
matches!(module, KnownModule::SqlalchemySqlSelectableConstructors)
}
} }
} }
@ -1896,6 +1903,56 @@ impl KnownFunction {
overload.set_return_type(Type::module_literal(db, file, module)); overload.set_return_type(Type::module_literal(db, file, module));
} }
KnownFunction::SqlalchemySelect => {
// Try to extract types from InstrumentedAttribute[T] arguments.
// If all arguments are InstrumentedAttribute instances, we construct
// Select[tuple[T_1, T_2, ...]] where T_i are the inner types.
//
// We check the class via `class_literal.known(db)` rather than using
// `known_specialization` because the class may be re-exported and not
// directly importable from its canonical module.
let inner_types: Option<Vec<_>> = parameter_types
.iter()
.flatten()
.map(|param_type| {
let Type::NominalInstance(instance) = param_type else {
return None;
};
let class = instance.class(db);
let (class_literal, specialization) = class.class_literal(db);
if class_literal.known(db)
!= Some(KnownClass::SqlalchemyInstrumentedAttribute)
{
return None;
}
specialization?.types(db).first().copied()
})
.collect();
let Some(inner_types) = inner_types else {
// Fall back to whatever we infer from the function signature
return;
};
if inner_types.is_empty() {
return;
}
// Construct Select[tuple[T1, T2, ...]]
// We get the return type's class from the overload rather than looking
// it up via try_to_class_literal, since the class may be re-exported.
let Type::NominalInstance(return_instance) = overload.return_type() else {
return;
};
let select_class = return_instance.class(db).class_literal(db).0;
let tuple_type = Type::heterogeneous_tuple(db, inner_types);
let class_type = select_class.apply_specialization(db, |generic_context| {
generic_context.specialize(db, vec![tuple_type].into())
});
overload.set_return_type(Type::instance(db, class_type));
}
_ => {} _ => {}
} }
} }
@ -1964,6 +2021,8 @@ pub(crate) mod tests {
KnownFunction::ImportModule => KnownModule::ImportLib, KnownFunction::ImportModule => KnownModule::ImportLib,
KnownFunction::NamedTuple => KnownModule::Collections, KnownFunction::NamedTuple => KnownModule::Collections,
KnownFunction::SqlalchemySelect => continue,
}; };
let function_definition = known_module_symbol(&db, module, function_name) let function_definition = known_module_symbol(&db, module, function_name)