mirror of https://github.com/astral-sh/ruff
[ty] Temporary SQLAlchemy special-case
This commit is contained in:
parent
4e67a219bb
commit
352628e986
|
|
@ -106,45 +106,36 @@ reveal_type(admin_users) # revealed: Sequence[User]
|
|||
We can also specify particular columns to select:
|
||||
|
||||
```py
|
||||
reveal_type(User.id) # revealed: InstrumentedAttribute[int]
|
||||
stmt = select(User.id, User.name)
|
||||
# TODO: should be `Select[tuple[int, str]]`
|
||||
reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]]
|
||||
reveal_type(stmt) # revealed: Select[tuple[int, str]]
|
||||
|
||||
ids_and_names = session.execute(stmt).all()
|
||||
# TODO: should be `Sequence[Row[tuple[int, str]]]`
|
||||
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]]
|
||||
reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]]
|
||||
|
||||
for row in session.execute(stmt):
|
||||
# TODO: should be `Row[tuple[int, str]]`
|
||||
reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]]
|
||||
reveal_type(row) # revealed: Row[tuple[int, str]]
|
||||
|
||||
for user_id, name in session.execute(stmt).tuples():
|
||||
# TODO: should be `int`
|
||||
reveal_type(user_id) # revealed: Unknown
|
||||
# TODO: should be `str`
|
||||
reveal_type(name) # revealed: Unknown
|
||||
reveal_type(user_id) # revealed: int
|
||||
reveal_type(name) # revealed: str
|
||||
|
||||
result = session.execute(stmt)
|
||||
row = result.one_or_none()
|
||||
assert row is not None
|
||||
(user_id, name) = row._tuple()
|
||||
# TODO: should be `int`
|
||||
reveal_type(user_id) # revealed: Unknown
|
||||
# TODO: should be `str`
|
||||
reveal_type(name) # revealed: Unknown
|
||||
reveal_type(user_id) # revealed: int
|
||||
reveal_type(name) # revealed: str
|
||||
|
||||
stmt = select(User.id).where(User.name == "Alice")
|
||||
|
||||
# TODO: should be `Select[tuple[int]]`
|
||||
reveal_type(stmt) # revealed: Select[tuple[Unknown]]
|
||||
reveal_type(stmt) # revealed: Select[tuple[int]]
|
||||
|
||||
alice_id = session.scalars(stmt).first()
|
||||
# TODO: should be `int | None`
|
||||
reveal_type(alice_id) # revealed: Unknown | None
|
||||
reveal_type(alice_id) # revealed: int | None
|
||||
|
||||
alice_id = session.scalar(stmt)
|
||||
# TODO: should be `int | None`
|
||||
reveal_type(alice_id) # revealed: Unknown | None
|
||||
reveal_type(alice_id) # revealed: int | None
|
||||
```
|
||||
|
||||
Using the legacy `query` API also works:
|
||||
|
|
@ -203,8 +194,6 @@ async def test_async(session: AsyncSession):
|
|||
stmt = select(User.id, User.name)
|
||||
result = await session.execute(stmt)
|
||||
for user_id, name in result.tuples():
|
||||
# TODO: should be `int`
|
||||
reveal_type(user_id) # revealed: Unknown
|
||||
# TODO: should be `str`
|
||||
reveal_type(name) # revealed: Unknown
|
||||
reveal_type(user_id) # revealed: int
|
||||
reveal_type(name) # revealed: str
|
||||
```
|
||||
|
|
|
|||
|
|
@ -335,6 +335,12 @@ pub enum KnownModule {
|
|||
#[cfg(test)]
|
||||
Uuid,
|
||||
Warnings,
|
||||
#[strum(serialize = "sqlalchemy.sql.selectable")]
|
||||
SqlalchemySqlSelectable,
|
||||
#[strum(serialize = "sqlalchemy.sql._selectable_constructors")]
|
||||
SqlalchemySqlSelectableConstructors,
|
||||
#[strum(serialize = "sqlalchemy.orm.attributes")]
|
||||
SqlalchemyOrmAttributes,
|
||||
}
|
||||
|
||||
impl KnownModule {
|
||||
|
|
@ -363,6 +369,9 @@ impl KnownModule {
|
|||
#[cfg(test)]
|
||||
Self::Uuid => "uuid",
|
||||
Self::Templatelib => "string.templatelib",
|
||||
Self::SqlalchemySqlSelectable => "sqlalchemy.sql.selectable",
|
||||
Self::SqlalchemySqlSelectableConstructors => "sqlalchemy.sql._selectable_constructors",
|
||||
Self::SqlalchemyOrmAttributes => "sqlalchemy.orm.attributes",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -378,7 +387,20 @@ impl KnownModule {
|
|||
if search_path.is_standard_library() {
|
||||
Self::from_str(name.as_str()).ok()
|
||||
} else {
|
||||
None
|
||||
// For non-stdlib search paths, check for known third-party modules
|
||||
Self::try_from_third_party_name(name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a known module for third-party packages, if applicable.
|
||||
fn try_from_third_party_name(name: &ModuleName) -> Option<Self> {
|
||||
match name.as_str() {
|
||||
"sqlalchemy.sql.selectable" => Some(Self::SqlalchemySqlSelectable),
|
||||
"sqlalchemy.sql._selectable_constructors" => {
|
||||
Some(Self::SqlalchemySqlSelectableConstructors)
|
||||
}
|
||||
"sqlalchemy.orm.attributes" => Some(Self::SqlalchemyOrmAttributes),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -419,6 +441,11 @@ mod tests {
|
|||
let stdlib_search_path = SearchPath::vendored_stdlib();
|
||||
|
||||
for module in KnownModule::iter() {
|
||||
// Third-party modules aren't available in the vendored stdlib
|
||||
if module.is_third_party() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let module_name = module.name();
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
|||
|
|
@ -4207,6 +4207,9 @@ pub enum KnownClass {
|
|||
ConstraintSet,
|
||||
GenericContext,
|
||||
Specialization,
|
||||
// sqlalchemy
|
||||
SqlalchemySelect,
|
||||
SqlalchemyInstrumentedAttribute,
|
||||
}
|
||||
|
||||
impl KnownClass {
|
||||
|
|
@ -4315,7 +4318,9 @@ impl KnownClass {
|
|||
| Self::GenericContext
|
||||
| Self::Specialization
|
||||
| Self::ProtocolMeta
|
||||
| Self::TypedDictFallback => Some(Truthiness::Ambiguous),
|
||||
| Self::TypedDictFallback
|
||||
| Self::SqlalchemySelect
|
||||
| Self::SqlalchemyInstrumentedAttribute => Some(Truthiness::Ambiguous),
|
||||
|
||||
Self::Tuple => None,
|
||||
}
|
||||
|
|
@ -4405,7 +4410,9 @@ impl KnownClass {
|
|||
| KnownClass::BuiltinFunctionType
|
||||
| KnownClass::ProtocolMeta
|
||||
| KnownClass::Template
|
||||
| KnownClass::Path => false,
|
||||
| KnownClass::Path
|
||||
| KnownClass::SqlalchemySelect
|
||||
| KnownClass::SqlalchemyInstrumentedAttribute => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4492,7 +4499,9 @@ impl KnownClass {
|
|||
| KnownClass::BuiltinFunctionType
|
||||
| KnownClass::ProtocolMeta
|
||||
| KnownClass::Template
|
||||
| KnownClass::Path => false,
|
||||
| KnownClass::Path
|
||||
| KnownClass::SqlalchemySelect
|
||||
| KnownClass::SqlalchemyInstrumentedAttribute => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4578,7 +4587,9 @@ impl KnownClass {
|
|||
| KnownClass::BuiltinFunctionType
|
||||
| KnownClass::ProtocolMeta
|
||||
| KnownClass::Template
|
||||
| KnownClass::Path => false,
|
||||
| KnownClass::Path
|
||||
| KnownClass::SqlalchemySelect
|
||||
| KnownClass::SqlalchemyInstrumentedAttribute => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4677,7 +4688,9 @@ impl KnownClass {
|
|||
| Self::ProtocolMeta
|
||||
| Self::Template
|
||||
| Self::Path
|
||||
| Self::Mapping => false,
|
||||
| Self::Mapping
|
||||
| Self::SqlalchemySelect
|
||||
| Self::SqlalchemyInstrumentedAttribute => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -4766,7 +4779,9 @@ impl KnownClass {
|
|||
| KnownClass::ConstraintSet
|
||||
| KnownClass::GenericContext
|
||||
| KnownClass::Specialization
|
||||
| KnownClass::InitVar => false,
|
||||
| KnownClass::InitVar
|
||||
| KnownClass::SqlalchemySelect
|
||||
| KnownClass::SqlalchemyInstrumentedAttribute => false,
|
||||
KnownClass::NamedTupleFallback | KnownClass::TypedDictFallback => true,
|
||||
}
|
||||
}
|
||||
|
|
@ -4882,6 +4897,8 @@ impl KnownClass {
|
|||
Self::Template => "Template",
|
||||
Self::Path => "Path",
|
||||
Self::ProtocolMeta => "_ProtocolMeta",
|
||||
Self::SqlalchemySelect => "Select",
|
||||
Self::SqlalchemyInstrumentedAttribute => "InstrumentedAttribute",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -5203,6 +5220,8 @@ impl KnownClass {
|
|||
| Self::Specialization => KnownModule::TyExtensions,
|
||||
Self::Template => KnownModule::Templatelib,
|
||||
Self::Path => KnownModule::Pathlib,
|
||||
Self::SqlalchemySelect => KnownModule::SqlalchemySqlSelectable,
|
||||
Self::SqlalchemyInstrumentedAttribute => KnownModule::SqlalchemyOrmAttributes,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -5291,7 +5310,9 @@ impl KnownClass {
|
|||
| Self::BuiltinFunctionType
|
||||
| Self::ProtocolMeta
|
||||
| Self::Template
|
||||
| Self::Path => Some(false),
|
||||
| Self::Path
|
||||
| Self::SqlalchemySelect
|
||||
| Self::SqlalchemyInstrumentedAttribute => Some(false),
|
||||
|
||||
Self::Tuple => None,
|
||||
}
|
||||
|
|
@ -5383,7 +5404,9 @@ impl KnownClass {
|
|||
| Self::BuiltinFunctionType
|
||||
| Self::ProtocolMeta
|
||||
| Self::Template
|
||||
| Self::Path => false,
|
||||
| Self::Path
|
||||
| Self::SqlalchemySelect
|
||||
| Self::SqlalchemyInstrumentedAttribute => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -5489,6 +5512,8 @@ impl KnownClass {
|
|||
"Template" => &[Self::Template],
|
||||
"Path" => &[Self::Path],
|
||||
"_ProtocolMeta" => &[Self::ProtocolMeta],
|
||||
"Select" => &[Self::SqlalchemySelect],
|
||||
"InstrumentedAttribute" => &[Self::SqlalchemyInstrumentedAttribute],
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
|
|
@ -5569,7 +5594,9 @@ impl KnownClass {
|
|||
| Self::Awaitable
|
||||
| Self::Generator
|
||||
| Self::Template
|
||||
| Self::Path => module == self.canonical_module(db),
|
||||
| Self::Path
|
||||
| Self::SqlalchemySelect
|
||||
| Self::SqlalchemyInstrumentedAttribute => module == self.canonical_module(db),
|
||||
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
|
||||
Self::SpecialForm
|
||||
| Self::TypeAliasType
|
||||
|
|
@ -5924,6 +5951,10 @@ mod tests {
|
|||
source: PythonVersionSource::default(),
|
||||
});
|
||||
for class in KnownClass::iter() {
|
||||
if class.canonical_module(&db).is_third_party() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let class_name = class.name(&db);
|
||||
let class_module =
|
||||
resolve_module_confident(&db, &class.canonical_module(&db).name()).unwrap();
|
||||
|
|
@ -5952,6 +5983,10 @@ mod tests {
|
|||
});
|
||||
|
||||
for class in KnownClass::iter() {
|
||||
if class.canonical_module(&db).is_third_party() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check the class can be looked up successfully
|
||||
class.try_to_class_literal_without_logging(&db).unwrap();
|
||||
|
||||
|
|
@ -5977,6 +6012,7 @@ mod tests {
|
|||
// This makes the test far faster as it minimizes the number of times
|
||||
// we need to change the Python version in the loop.
|
||||
let mut classes: Vec<(KnownClass, PythonVersion)> = KnownClass::iter()
|
||||
.filter(|class| !class.canonical_module(&db).is_third_party())
|
||||
.map(|class| {
|
||||
let version_added = match class {
|
||||
KnownClass::Template => PythonVersion::PY314,
|
||||
|
|
|
|||
|
|
@ -1353,6 +1353,10 @@ pub enum KnownFunction {
|
|||
RevealProtocolInterface,
|
||||
/// `ty_extensions.reveal_mro`
|
||||
RevealMro,
|
||||
|
||||
/// `sqlalchemy.select`
|
||||
#[strum(serialize = "select")]
|
||||
SqlalchemySelect,
|
||||
}
|
||||
|
||||
impl KnownFunction {
|
||||
|
|
@ -1425,6 +1429,9 @@ impl KnownFunction {
|
|||
|
||||
Self::TypeCheckOnly => matches!(module, KnownModule::Typing),
|
||||
Self::NamedTuple => matches!(module, KnownModule::Collections),
|
||||
Self::SqlalchemySelect => {
|
||||
matches!(module, KnownModule::SqlalchemySqlSelectableConstructors)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1896,6 +1903,56 @@ impl KnownFunction {
|
|||
|
||||
overload.set_return_type(Type::module_literal(db, file, module));
|
||||
}
|
||||
|
||||
KnownFunction::SqlalchemySelect => {
|
||||
// Try to extract types from InstrumentedAttribute[T] arguments.
|
||||
// If all arguments are InstrumentedAttribute instances, we construct
|
||||
// Select[tuple[T_1, T_2, ...]] where T_i are the inner types.
|
||||
//
|
||||
// We check the class via `class_literal.known(db)` rather than using
|
||||
// `known_specialization` because the class may be re-exported and not
|
||||
// directly importable from its canonical module.
|
||||
let inner_types: Option<Vec<_>> = parameter_types
|
||||
.iter()
|
||||
.flatten()
|
||||
.map(|param_type| {
|
||||
let Type::NominalInstance(instance) = param_type else {
|
||||
return None;
|
||||
};
|
||||
let class = instance.class(db);
|
||||
let (class_literal, specialization) = class.class_literal(db);
|
||||
if class_literal.known(db)
|
||||
!= Some(KnownClass::SqlalchemyInstrumentedAttribute)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
specialization?.types(db).first().copied()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let Some(inner_types) = inner_types else {
|
||||
// Fall back to whatever we infer from the function signature
|
||||
return;
|
||||
};
|
||||
|
||||
if inner_types.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Construct Select[tuple[T1, T2, ...]]
|
||||
// We get the return type's class from the overload rather than looking
|
||||
// it up via try_to_class_literal, since the class may be re-exported.
|
||||
let Type::NominalInstance(return_instance) = overload.return_type() else {
|
||||
return;
|
||||
};
|
||||
let select_class = return_instance.class(db).class_literal(db).0;
|
||||
let tuple_type = Type::heterogeneous_tuple(db, inner_types);
|
||||
let class_type = select_class.apply_specialization(db, |generic_context| {
|
||||
generic_context.specialize(db, vec![tuple_type].into())
|
||||
});
|
||||
overload.set_return_type(Type::instance(db, class_type));
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
|
@ -1964,6 +2021,8 @@ pub(crate) mod tests {
|
|||
|
||||
KnownFunction::ImportModule => KnownModule::ImportLib,
|
||||
KnownFunction::NamedTuple => KnownModule::Collections,
|
||||
|
||||
KnownFunction::SqlalchemySelect => continue,
|
||||
};
|
||||
|
||||
let function_definition = known_module_symbol(&db, module, function_name)
|
||||
|
|
|
|||
Loading…
Reference in New Issue