[ty] Temporary SQLAlchemy special-case

This commit is contained in:
David Peter 2025-12-09 10:34:09 +01:00
parent 4e67a219bb
commit 352628e986
4 changed files with 145 additions and 34 deletions

View File

@ -106,45 +106,36 @@ reveal_type(admin_users) # revealed: Sequence[User]
We can also specify particular columns to select:
```py
reveal_type(User.id) # revealed: InstrumentedAttribute[int]
stmt = select(User.id, User.name)
# TODO: should be `Select[tuple[int, str]]`
reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]]
reveal_type(stmt) # revealed: Select[tuple[int, str]]
ids_and_names = session.execute(stmt).all()
# TODO: should be `Sequence[Row[tuple[int, str]]]`
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]]
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]]
for row in session.execute(stmt):
# TODO: should be `Row[tuple[int, str]]`
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
reveal_type(row) # revealed: Row[tuple[int, str]]
for user_id, name in session.execute(stmt).tuples():
# TODO: should be `int`
reveal_type(user_id) # revealed: Unknown
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
reveal_type(user_id) # revealed: int
reveal_type(name) # revealed: str
result = session.execute(stmt)
row = result.one_or_none()
assert row is not None
(user_id, name) = row._tuple()
# TODO: should be `int`
reveal_type(user_id) # revealed: Unknown
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
reveal_type(user_id) # revealed: int
reveal_type(name) # revealed: str
stmt = select(User.id).where(User.name == "Alice")
# TODO: should be `Select[tuple[int]]`
reveal_type(stmt) # revealed: Select[tuple[Unknown]]
reveal_type(stmt) # revealed: Select[tuple[int]]
alice_id = session.scalars(stmt).first()
# TODO: should be `int | None`
reveal_type(alice_id) # revealed: Unknown | None
reveal_type(alice_id) # revealed: int | None
alice_id = session.scalar(stmt)
# TODO: should be `int | None`
reveal_type(alice_id) # revealed: Unknown | None
reveal_type(alice_id) # revealed: int | None
```
Using the legacy `query` API also works:
@ -203,8 +194,6 @@ async def test_async(session: AsyncSession):
stmt = select(User.id, User.name)
result = await session.execute(stmt)
for user_id, name in result.tuples():
# TODO: should be `int`
reveal_type(user_id) # revealed: Unknown
# TODO: should be `str`
reveal_type(name) # revealed: Unknown
reveal_type(user_id) # revealed: int
reveal_type(name) # revealed: str
```

View File

@ -335,6 +335,12 @@ pub enum KnownModule {
#[cfg(test)]
Uuid,
Warnings,
#[strum(serialize = "sqlalchemy.sql.selectable")]
SqlalchemySqlSelectable,
#[strum(serialize = "sqlalchemy.sql._selectable_constructors")]
SqlalchemySqlSelectableConstructors,
#[strum(serialize = "sqlalchemy.orm.attributes")]
SqlalchemyOrmAttributes,
}
impl KnownModule {
@ -363,6 +369,9 @@ impl KnownModule {
#[cfg(test)]
Self::Uuid => "uuid",
Self::Templatelib => "string.templatelib",
Self::SqlalchemySqlSelectable => "sqlalchemy.sql.selectable",
Self::SqlalchemySqlSelectableConstructors => "sqlalchemy.sql._selectable_constructors",
Self::SqlalchemyOrmAttributes => "sqlalchemy.orm.attributes",
}
}
@ -378,7 +387,20 @@ impl KnownModule {
if search_path.is_standard_library() {
Self::from_str(name.as_str()).ok()
} else {
None
// For non-stdlib search paths, check for known third-party modules
Self::try_from_third_party_name(name)
}
}
/// Returns a known module for third-party packages, if applicable.
fn try_from_third_party_name(name: &ModuleName) -> Option<Self> {
match name.as_str() {
"sqlalchemy.sql.selectable" => Some(Self::SqlalchemySqlSelectable),
"sqlalchemy.sql._selectable_constructors" => {
Some(Self::SqlalchemySqlSelectableConstructors)
}
"sqlalchemy.orm.attributes" => Some(Self::SqlalchemyOrmAttributes),
_ => None,
}
}
@ -419,6 +441,11 @@ mod tests {
let stdlib_search_path = SearchPath::vendored_stdlib();
for module in KnownModule::iter() {
// Third-party modules aren't available in the vendored stdlib
if module.is_third_party() {
continue;
}
let module_name = module.name();
assert_eq!(

View File

@ -4207,6 +4207,9 @@ pub enum KnownClass {
ConstraintSet,
GenericContext,
Specialization,
// sqlalchemy
SqlalchemySelect,
SqlalchemyInstrumentedAttribute,
}
impl KnownClass {
@ -4315,7 +4318,9 @@ impl KnownClass {
| Self::GenericContext
| Self::Specialization
| Self::ProtocolMeta
| Self::TypedDictFallback => Some(Truthiness::Ambiguous),
| Self::TypedDictFallback
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => Some(Truthiness::Ambiguous),
Self::Tuple => None,
}
@ -4405,7 +4410,9 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta
| KnownClass::Template
| KnownClass::Path => false,
| KnownClass::Path
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
}
}
@ -4492,7 +4499,9 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta
| KnownClass::Template
| KnownClass::Path => false,
| KnownClass::Path
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
}
}
@ -4578,7 +4587,9 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta
| KnownClass::Template
| KnownClass::Path => false,
| KnownClass::Path
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
}
}
@ -4677,7 +4688,9 @@ impl KnownClass {
| Self::ProtocolMeta
| Self::Template
| Self::Path
| Self::Mapping => false,
| Self::Mapping
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => false,
}
}
@ -4766,7 +4779,9 @@ impl KnownClass {
| KnownClass::ConstraintSet
| KnownClass::GenericContext
| KnownClass::Specialization
| KnownClass::InitVar => false,
| KnownClass::InitVar
| KnownClass::SqlalchemySelect
| KnownClass::SqlalchemyInstrumentedAttribute => false,
KnownClass::NamedTupleFallback | KnownClass::TypedDictFallback => true,
}
}
@ -4882,6 +4897,8 @@ impl KnownClass {
Self::Template => "Template",
Self::Path => "Path",
Self::ProtocolMeta => "_ProtocolMeta",
Self::SqlalchemySelect => "Select",
Self::SqlalchemyInstrumentedAttribute => "InstrumentedAttribute",
}
}
@ -5203,6 +5220,8 @@ impl KnownClass {
| Self::Specialization => KnownModule::TyExtensions,
Self::Template => KnownModule::Templatelib,
Self::Path => KnownModule::Pathlib,
Self::SqlalchemySelect => KnownModule::SqlalchemySqlSelectable,
Self::SqlalchemyInstrumentedAttribute => KnownModule::SqlalchemyOrmAttributes,
}
}
@ -5291,7 +5310,9 @@ impl KnownClass {
| Self::BuiltinFunctionType
| Self::ProtocolMeta
| Self::Template
| Self::Path => Some(false),
| Self::Path
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => Some(false),
Self::Tuple => None,
}
@ -5383,7 +5404,9 @@ impl KnownClass {
| Self::BuiltinFunctionType
| Self::ProtocolMeta
| Self::Template
| Self::Path => false,
| Self::Path
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => false,
}
}
@ -5489,6 +5512,8 @@ impl KnownClass {
"Template" => &[Self::Template],
"Path" => &[Self::Path],
"_ProtocolMeta" => &[Self::ProtocolMeta],
"Select" => &[Self::SqlalchemySelect],
"InstrumentedAttribute" => &[Self::SqlalchemyInstrumentedAttribute],
_ => return None,
};
@ -5569,7 +5594,9 @@ impl KnownClass {
| Self::Awaitable
| Self::Generator
| Self::Template
| Self::Path => module == self.canonical_module(db),
| Self::Path
| Self::SqlalchemySelect
| Self::SqlalchemyInstrumentedAttribute => module == self.canonical_module(db),
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
Self::SpecialForm
| Self::TypeAliasType
@ -5924,6 +5951,10 @@ mod tests {
source: PythonVersionSource::default(),
});
for class in KnownClass::iter() {
if class.canonical_module(&db).is_third_party() {
continue;
}
let class_name = class.name(&db);
let class_module =
resolve_module_confident(&db, &class.canonical_module(&db).name()).unwrap();
@ -5952,6 +5983,10 @@ mod tests {
});
for class in KnownClass::iter() {
if class.canonical_module(&db).is_third_party() {
continue;
}
// Check the class can be looked up successfully
class.try_to_class_literal_without_logging(&db).unwrap();
@ -5977,6 +6012,7 @@ mod tests {
// This makes the test far faster as it minimizes the number of times
// we need to change the Python version in the loop.
let mut classes: Vec<(KnownClass, PythonVersion)> = KnownClass::iter()
.filter(|class| !class.canonical_module(&db).is_third_party())
.map(|class| {
let version_added = match class {
KnownClass::Template => PythonVersion::PY314,

View File

@ -1353,6 +1353,10 @@ pub enum KnownFunction {
RevealProtocolInterface,
/// `ty_extensions.reveal_mro`
RevealMro,
/// `sqlalchemy.select`
#[strum(serialize = "select")]
SqlalchemySelect,
}
impl KnownFunction {
@ -1425,6 +1429,9 @@ impl KnownFunction {
Self::TypeCheckOnly => matches!(module, KnownModule::Typing),
Self::NamedTuple => matches!(module, KnownModule::Collections),
Self::SqlalchemySelect => {
matches!(module, KnownModule::SqlalchemySqlSelectableConstructors)
}
}
}
@ -1896,6 +1903,56 @@ impl KnownFunction {
overload.set_return_type(Type::module_literal(db, file, module));
}
KnownFunction::SqlalchemySelect => {
// Try to extract types from InstrumentedAttribute[T] arguments.
// If all arguments are InstrumentedAttribute instances, we construct
// Select[tuple[T_1, T_2, ...]] where T_i are the inner types.
//
// We check the class via `class_literal.known(db)` rather than using
// `known_specialization` because the class may be re-exported and not
// directly importable from its canonical module.
let inner_types: Option<Vec<_>> = parameter_types
.iter()
.flatten()
.map(|param_type| {
let Type::NominalInstance(instance) = param_type else {
return None;
};
let class = instance.class(db);
let (class_literal, specialization) = class.class_literal(db);
if class_literal.known(db)
!= Some(KnownClass::SqlalchemyInstrumentedAttribute)
{
return None;
}
specialization?.types(db).first().copied()
})
.collect();
let Some(inner_types) = inner_types else {
// Fall back to whatever we infer from the function signature
return;
};
if inner_types.is_empty() {
return;
}
// Construct Select[tuple[T1, T2, ...]]
// We get the return type's class from the overload rather than looking
// it up via try_to_class_literal, since the class may be re-exported.
let Type::NominalInstance(return_instance) = overload.return_type() else {
return;
};
let select_class = return_instance.class(db).class_literal(db).0;
let tuple_type = Type::heterogeneous_tuple(db, inner_types);
let class_type = select_class.apply_specialization(db, |generic_context| {
generic_context.specialize(db, vec![tuple_type].into())
});
overload.set_return_type(Type::instance(db, class_type));
}
_ => {}
}
}
@ -1964,6 +2021,8 @@ pub(crate) mod tests {
KnownFunction::ImportModule => KnownModule::ImportLib,
KnownFunction::NamedTuple => KnownModule::Collections,
KnownFunction::SqlalchemySelect => continue,
};
let function_definition = known_module_symbol(&db, module, function_name)