diff --git a/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md b/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md index 61e6668de1..dbabf687ab 100644 --- a/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md +++ b/crates/ty_python_semantic/resources/mdtest/external/sqlalchemy.md @@ -106,45 +106,36 @@ reveal_type(admin_users) # revealed: Sequence[User] We can also specify particular columns to select: ```py +reveal_type(User.id) # revealed: InstrumentedAttribute[int] stmt = select(User.id, User.name) -# TODO: should be `Select[tuple[int, str]]` -reveal_type(stmt) # revealed: Select[tuple[Unknown, Unknown]] +reveal_type(stmt) # revealed: Select[tuple[int, str]] ids_and_names = session.execute(stmt).all() -# TODO: should be `Sequence[Row[tuple[int, str]]]` -reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[Unknown, Unknown]]] +reveal_type(ids_and_names) # revealed: Sequence[Row[tuple[int, str]]] for row in session.execute(stmt): - # TODO: should be `Row[tuple[int, str]]` - reveal_type(row) # revealed: Row[tuple[Unknown, Unknown]] + reveal_type(row) # revealed: Row[tuple[int, str]] for user_id, name in session.execute(stmt).tuples(): - # TODO: should be `int` - reveal_type(user_id) # revealed: Unknown - # TODO: should be `str` - reveal_type(name) # revealed: Unknown + reveal_type(user_id) # revealed: int + reveal_type(name) # revealed: str result = session.execute(stmt) row = result.one_or_none() assert row is not None (user_id, name) = row._tuple() -# TODO: should be `int` -reveal_type(user_id) # revealed: Unknown -# TODO: should be `str` -reveal_type(name) # revealed: Unknown +reveal_type(user_id) # revealed: int +reveal_type(name) # revealed: str stmt = select(User.id).where(User.name == "Alice") -# TODO: should be `Select[tuple[int]]` -reveal_type(stmt) # revealed: Select[tuple[Unknown]] +reveal_type(stmt) # revealed: Select[tuple[int]] alice_id = session.scalars(stmt).first() -# TODO: should be `int | None` -reveal_type(alice_id) # revealed: Unknown | None +reveal_type(alice_id) # revealed: int | None alice_id = session.scalar(stmt) -# TODO: should be `int | None` -reveal_type(alice_id) # revealed: Unknown | None +reveal_type(alice_id) # revealed: int | None ``` Using the legacy `query` API also works: @@ -203,8 +194,6 @@ async def test_async(session: AsyncSession): stmt = select(User.id, User.name) result = await session.execute(stmt) for user_id, name in result.tuples(): - # TODO: should be `int` - reveal_type(user_id) # revealed: Unknown - # TODO: should be `str` - reveal_type(name) # revealed: Unknown + reveal_type(user_id) # revealed: int + reveal_type(name) # revealed: str ``` diff --git a/crates/ty_python_semantic/src/module_resolver/module.rs b/crates/ty_python_semantic/src/module_resolver/module.rs index 118c2aff45..a98af53f7f 100644 --- a/crates/ty_python_semantic/src/module_resolver/module.rs +++ b/crates/ty_python_semantic/src/module_resolver/module.rs @@ -335,6 +335,12 @@ pub enum KnownModule { #[cfg(test)] Uuid, Warnings, + #[strum(serialize = "sqlalchemy.sql.selectable")] + SqlalchemySqlSelectable, + #[strum(serialize = "sqlalchemy.sql._selectable_constructors")] + SqlalchemySqlSelectableConstructors, + #[strum(serialize = "sqlalchemy.orm.attributes")] + SqlalchemyOrmAttributes, } impl KnownModule { @@ -363,6 +369,9 @@ impl KnownModule { #[cfg(test)] Self::Uuid => "uuid", Self::Templatelib => "string.templatelib", + Self::SqlalchemySqlSelectable => "sqlalchemy.sql.selectable", + Self::SqlalchemySqlSelectableConstructors => "sqlalchemy.sql._selectable_constructors", + Self::SqlalchemyOrmAttributes => "sqlalchemy.orm.attributes", } } @@ -378,7 +387,20 @@ impl KnownModule { if search_path.is_standard_library() { Self::from_str(name.as_str()).ok() } else { - None + // For non-stdlib search paths, check for known third-party modules + Self::try_from_third_party_name(name) + } + } + + /// Returns a known module for third-party packages, if applicable. + fn try_from_third_party_name(name: &ModuleName) -> Option { + match name.as_str() { + "sqlalchemy.sql.selectable" => Some(Self::SqlalchemySqlSelectable), + "sqlalchemy.sql._selectable_constructors" => { + Some(Self::SqlalchemySqlSelectableConstructors) + } + "sqlalchemy.orm.attributes" => Some(Self::SqlalchemyOrmAttributes), + _ => None, } } @@ -419,6 +441,11 @@ mod tests { let stdlib_search_path = SearchPath::vendored_stdlib(); for module in KnownModule::iter() { + // Third-party modules aren't available in the vendored stdlib + if module.is_third_party() { + continue; + } + let module_name = module.name(); assert_eq!( diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 855e8922a0..5488cfc821 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -4207,6 +4207,9 @@ pub enum KnownClass { ConstraintSet, GenericContext, Specialization, + // sqlalchemy + SqlalchemySelect, + SqlalchemyInstrumentedAttribute, } impl KnownClass { @@ -4315,7 +4318,9 @@ impl KnownClass { | Self::GenericContext | Self::Specialization | Self::ProtocolMeta - | Self::TypedDictFallback => Some(Truthiness::Ambiguous), + | Self::TypedDictFallback + | Self::SqlalchemySelect + | Self::SqlalchemyInstrumentedAttribute => Some(Truthiness::Ambiguous), Self::Tuple => None, } @@ -4405,7 +4410,9 @@ impl KnownClass { | KnownClass::BuiltinFunctionType | KnownClass::ProtocolMeta | KnownClass::Template - | KnownClass::Path => false, + | KnownClass::Path + | KnownClass::SqlalchemySelect + | KnownClass::SqlalchemyInstrumentedAttribute => false, } } @@ -4492,7 +4499,9 @@ impl KnownClass { | KnownClass::BuiltinFunctionType | KnownClass::ProtocolMeta | KnownClass::Template - | KnownClass::Path => false, + | KnownClass::Path + | KnownClass::SqlalchemySelect + | KnownClass::SqlalchemyInstrumentedAttribute => false, } } @@ -4578,7 +4587,9 @@ impl KnownClass { | KnownClass::BuiltinFunctionType | KnownClass::ProtocolMeta | KnownClass::Template - | KnownClass::Path => false, + | KnownClass::Path + | KnownClass::SqlalchemySelect + | KnownClass::SqlalchemyInstrumentedAttribute => false, } } @@ -4677,7 +4688,9 @@ impl KnownClass { | Self::ProtocolMeta | Self::Template | Self::Path - | Self::Mapping => false, + | Self::Mapping + | Self::SqlalchemySelect + | Self::SqlalchemyInstrumentedAttribute => false, } } @@ -4766,7 +4779,9 @@ impl KnownClass { | KnownClass::ConstraintSet | KnownClass::GenericContext | KnownClass::Specialization - | KnownClass::InitVar => false, + | KnownClass::InitVar + | KnownClass::SqlalchemySelect + | KnownClass::SqlalchemyInstrumentedAttribute => false, KnownClass::NamedTupleFallback | KnownClass::TypedDictFallback => true, } } @@ -4882,6 +4897,8 @@ impl KnownClass { Self::Template => "Template", Self::Path => "Path", Self::ProtocolMeta => "_ProtocolMeta", + Self::SqlalchemySelect => "Select", + Self::SqlalchemyInstrumentedAttribute => "InstrumentedAttribute", } } @@ -5203,6 +5220,8 @@ impl KnownClass { | Self::Specialization => KnownModule::TyExtensions, Self::Template => KnownModule::Templatelib, Self::Path => KnownModule::Pathlib, + Self::SqlalchemySelect => KnownModule::SqlalchemySqlSelectable, + Self::SqlalchemyInstrumentedAttribute => KnownModule::SqlalchemyOrmAttributes, } } @@ -5291,7 +5310,9 @@ impl KnownClass { | Self::BuiltinFunctionType | Self::ProtocolMeta | Self::Template - | Self::Path => Some(false), + | Self::Path + | Self::SqlalchemySelect + | Self::SqlalchemyInstrumentedAttribute => Some(false), Self::Tuple => None, } @@ -5383,7 +5404,9 @@ impl KnownClass { | Self::BuiltinFunctionType | Self::ProtocolMeta | Self::Template - | Self::Path => false, + | Self::Path + | Self::SqlalchemySelect + | Self::SqlalchemyInstrumentedAttribute => false, } } @@ -5489,6 +5512,8 @@ impl KnownClass { "Template" => &[Self::Template], "Path" => &[Self::Path], "_ProtocolMeta" => &[Self::ProtocolMeta], + "Select" => &[Self::SqlalchemySelect], + "InstrumentedAttribute" => &[Self::SqlalchemyInstrumentedAttribute], _ => return None, }; @@ -5569,7 +5594,9 @@ impl KnownClass { | Self::Awaitable | Self::Generator | Self::Template - | Self::Path => module == self.canonical_module(db), + | Self::Path + | Self::SqlalchemySelect + | Self::SqlalchemyInstrumentedAttribute => module == self.canonical_module(db), Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types), Self::SpecialForm | Self::TypeAliasType @@ -5924,6 +5951,10 @@ mod tests { source: PythonVersionSource::default(), }); for class in KnownClass::iter() { + if class.canonical_module(&db).is_third_party() { + continue; + } + let class_name = class.name(&db); let class_module = resolve_module_confident(&db, &class.canonical_module(&db).name()).unwrap(); @@ -5952,6 +5983,10 @@ mod tests { }); for class in KnownClass::iter() { + if class.canonical_module(&db).is_third_party() { + continue; + } + // Check the class can be looked up successfully class.try_to_class_literal_without_logging(&db).unwrap(); @@ -5977,6 +6012,7 @@ mod tests { // This makes the test far faster as it minimizes the number of times // we need to change the Python version in the loop. let mut classes: Vec<(KnownClass, PythonVersion)> = KnownClass::iter() + .filter(|class| !class.canonical_module(&db).is_third_party()) .map(|class| { let version_added = match class { KnownClass::Template => PythonVersion::PY314, diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 6b6c798615..d6f3a7c2f4 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1353,6 +1353,10 @@ pub enum KnownFunction { RevealProtocolInterface, /// `ty_extensions.reveal_mro` RevealMro, + + /// `sqlalchemy.select` + #[strum(serialize = "select")] + SqlalchemySelect, } impl KnownFunction { @@ -1425,6 +1429,9 @@ impl KnownFunction { Self::TypeCheckOnly => matches!(module, KnownModule::Typing), Self::NamedTuple => matches!(module, KnownModule::Collections), + Self::SqlalchemySelect => { + matches!(module, KnownModule::SqlalchemySqlSelectableConstructors) + } } } @@ -1896,6 +1903,56 @@ impl KnownFunction { overload.set_return_type(Type::module_literal(db, file, module)); } + + KnownFunction::SqlalchemySelect => { + // Try to extract types from InstrumentedAttribute[T] arguments. + // If all arguments are InstrumentedAttribute instances, we construct + // Select[tuple[T_1, T_2, ...]] where T_i are the inner types. + // + // We check the class via `class_literal.known(db)` rather than using + // `known_specialization` because the class may be re-exported and not + // directly importable from its canonical module. + let inner_types: Option> = parameter_types + .iter() + .flatten() + .map(|param_type| { + let Type::NominalInstance(instance) = param_type else { + return None; + }; + let class = instance.class(db); + let (class_literal, specialization) = class.class_literal(db); + if class_literal.known(db) + != Some(KnownClass::SqlalchemyInstrumentedAttribute) + { + return None; + } + specialization?.types(db).first().copied() + }) + .collect(); + + let Some(inner_types) = inner_types else { + // Fall back to whatever we infer from the function signature + return; + }; + + if inner_types.is_empty() { + return; + } + + // Construct Select[tuple[T1, T2, ...]] + // We get the return type's class from the overload rather than looking + // it up via try_to_class_literal, since the class may be re-exported. + let Type::NominalInstance(return_instance) = overload.return_type() else { + return; + }; + let select_class = return_instance.class(db).class_literal(db).0; + let tuple_type = Type::heterogeneous_tuple(db, inner_types); + let class_type = select_class.apply_specialization(db, |generic_context| { + generic_context.specialize(db, vec![tuple_type].into()) + }); + overload.set_return_type(Type::instance(db, class_type)); + } + _ => {} } } @@ -1964,6 +2021,8 @@ pub(crate) mod tests { KnownFunction::ImportModule => KnownModule::ImportLib, KnownFunction::NamedTuple => KnownModule::Collections, + + KnownFunction::SqlalchemySelect => continue, }; let function_definition = known_module_symbol(&db, module, function_name)