[ty] Explicit control over cycle entrypoint

This commit is contained in:
David Peter 2025-09-02 16:00:52 +02:00
parent bbfcf6e111
commit 8e52fa2fab
7 changed files with 330 additions and 304 deletions

View File

@ -182,7 +182,7 @@ impl<'db> DunderAllNamesCollector<'db> {
/// ///
/// This function panics if `expr` was not marked as a standalone expression during semantic indexing. /// This function panics if `expr` was not marked as a standalone expression during semantic indexing.
fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> {
infer_expression_types(self.db, self.index.expression(expr)).expression_type(expr) infer_expression_types(self.db, self.index.expression(expr), false).expression_type(expr)
} }
/// Evaluate the given expression and return its truthiness. /// Evaluate the given expression and return its truthiness.

View File

@ -328,10 +328,10 @@ fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type
fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> { fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> {
match kind { match kind {
PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton), PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton),
PatternPredicateKind::Value(value) => infer_expression_type(db, *value), PatternPredicateKind::Value(value) => infer_expression_type(db, *value, false),
PatternPredicateKind::Class(class_expr, kind) => { PatternPredicateKind::Class(class_expr, kind) => {
if kind.is_irrefutable() { if kind.is_irrefutable() {
infer_expression_type(db, *class_expr) infer_expression_type(db, *class_expr, false)
.to_instance(db) .to_instance(db)
.unwrap_or(Type::Never) .unwrap_or(Type::Never)
} else { } else {
@ -718,7 +718,7 @@ impl ReachabilityConstraints {
) -> Truthiness { ) -> Truthiness {
match predicate_kind { match predicate_kind {
PatternPredicateKind::Value(value) => { PatternPredicateKind::Value(value) => {
let value_ty = infer_expression_type(db, *value); let value_ty = infer_expression_type(db, *value, false);
if subject_ty.is_single_valued(db) { if subject_ty.is_single_valued(db) {
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty)) Truthiness::from(subject_ty.is_equivalent_to(db, value_ty))
@ -769,7 +769,7 @@ impl ReachabilityConstraints {
truthiness truthiness
} }
PatternPredicateKind::Class(class_expr, kind) => { PatternPredicateKind::Class(class_expr, kind) => {
let class_ty = infer_expression_type(db, *class_expr).to_instance(db); let class_ty = infer_expression_type(db, *class_expr, false).to_instance(db);
class_ty.map_or(Truthiness::Ambiguous, |class_ty| { class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
if subject_ty.is_subtype_of(db, class_ty) { if subject_ty.is_subtype_of(db, class_ty) {
@ -797,7 +797,7 @@ impl ReachabilityConstraints {
} }
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness { fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
let subject_ty = infer_expression_type(db, predicate.subject(db)); let subject_ty = infer_expression_type(db, predicate.subject(db), false);
let narrowed_subject_ty = IntersectionBuilder::new(db) let narrowed_subject_ty = IntersectionBuilder::new(db)
.add_positive(subject_ty) .add_positive(subject_ty)
@ -837,7 +837,7 @@ impl ReachabilityConstraints {
// selection algorithm). // selection algorithm).
// Avoiding this on the happy-path is important because these constraints can be // Avoiding this on the happy-path is important because these constraints can be
// very large in number, since we add them on all statement level function calls. // very large in number, since we add them on all statement level function calls.
let ty = infer_expression_type(db, callable); let ty = infer_expression_type(db, callable, false);
// Short-circuit for well known types that are known not to return `Never` when called. // Short-circuit for well known types that are known not to return `Never` when called.
// Without the short-circuit, we've seen that threads keep blocking each other // Without the short-circuit, we've seen that threads keep blocking each other
@ -875,7 +875,7 @@ impl ReachabilityConstraints {
} else if all_overloads_return_never { } else if all_overloads_return_never {
Truthiness::AlwaysTrue Truthiness::AlwaysTrue
} else { } else {
let call_expr_ty = infer_expression_type(db, call_expr); let call_expr_ty = infer_expression_type(db, call_expr, false);
if call_expr_ty.is_equivalent_to(db, Type::Never) { if call_expr_ty.is_equivalent_to(db, Type::Never) {
Truthiness::AlwaysTrue Truthiness::AlwaysTrue
} else { } else {

View File

@ -3337,7 +3337,12 @@ impl<'db> Type<'db> {
name: Name, name: Name,
policy: MemberLookupPolicy, policy: MemberLookupPolicy,
) -> PlaceAndQualifiers<'db> { ) -> PlaceAndQualifiers<'db> {
tracing::trace!("member_lookup_with_policy: {}.{}", self.display(db), name); let _span = tracing::trace_span!(
"member_lookup_with_policy",
ty = self.display(db).to_string(),
?name
)
.entered();
if name == "__class__" { if name == "__class__" {
return Place::bound(self.dunder_class(db)).into(); return Place::bound(self.dunder_class(db)).into();
} }
@ -10423,11 +10428,7 @@ static_assertions::assert_eq_size!(Type, [u8; 16]);
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;
use crate::db::tests::{TestDbBuilder, setup_db}; use crate::db::tests::{TestDbBuilder, setup_db};
use crate::place::{global_symbol, typing_extensions_symbol, typing_symbol}; use crate::place::{typing_extensions_symbol, typing_symbol};
use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
use ruff_db::system::DbWithWritableSystem as _;
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::PythonVersion; use ruff_python_ast::PythonVersion;
use test_case::test_case; use test_case::test_case;
@ -10468,62 +10469,62 @@ pub(crate) mod tests {
/// Inferring the result of a call-expression shouldn't need to re-run after /// Inferring the result of a call-expression shouldn't need to re-run after
/// a trivial change to the function's file (e.g. by adding a docstring to the function). /// a trivial change to the function's file (e.g. by adding a docstring to the function).
#[test] // #[test]
fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { // fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> {
let mut db = setup_db(); // let mut db = setup_db();
db.write_dedented( // db.write_dedented(
"src/foo.py", // "src/foo.py",
r#" // r#"
def foo() -> int: // def foo() -> int:
return 5 // return 5
"#, // "#,
)?; // )?;
db.write_dedented( // db.write_dedented(
"src/bar.py", // "src/bar.py",
r#" // r#"
from foo import foo // from foo import foo
a = foo() // a = foo()
"#, // "#,
)?; // )?;
let bar = system_path_to_file(&db, "src/bar.py")?; // let bar = system_path_to_file(&db, "src/bar.py")?;
let a = global_symbol(&db, bar, "a").place; // let a = global_symbol(&db, bar, "a").place;
assert_eq!( // assert_eq!(
a.expect_type(), // a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)]) // UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
); // );
// Add a docstring to foo to trigger a re-run. // // Add a docstring to foo to trigger a re-run.
// The bar-call site of foo should not be re-run because of that // // The bar-call site of foo should not be re-run because of that
db.write_dedented( // db.write_dedented(
"src/foo.py", // "src/foo.py",
r#" // r#"
def foo() -> int: // def foo() -> int:
"Computes a value" // "Computes a value"
return 5 // return 5
"#, // "#,
)?; // )?;
db.clear_salsa_events(); // db.clear_salsa_events();
let a = global_symbol(&db, bar, "a").place; // let a = global_symbol(&db, bar, "a").place;
assert_eq!( // assert_eq!(
a.expect_type(), // a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)]) // UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
); // );
let events = db.take_salsa_events(); // let events = db.take_salsa_events();
let module = parsed_module(&db, bar).load(&db); // let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; // let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call); // let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events); // assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events, false);
Ok(()) // Ok(())
} // }
/// All other tests also make sure that `Type::Todo` works as expected. This particular /// All other tests also make sure that `Type::Todo` works as expected. This particular
/// test makes sure that we handle `Todo` types correctly, even if they originate from /// test makes sure that we handle `Todo` types correctly, even if they originate from

View File

@ -2882,7 +2882,7 @@ impl<'db> ClassLiteral<'db> {
// `self.SOME_CONSTANT: Final = 1`, infer the type from the value // `self.SOME_CONSTANT: Final = 1`, infer the type from the value
// on the right-hand side. // on the right-hand side.
let inferred_ty = infer_expression_type(db, index.expression(value)); let inferred_ty = infer_expression_type(db, index.expression(value), true);
return Place::bound(inferred_ty).with_qualifiers(all_qualifiers); return Place::bound(inferred_ty).with_qualifiers(all_qualifiers);
} }
@ -2968,6 +2968,7 @@ impl<'db> ClassLiteral<'db> {
let inferred_ty = infer_expression_type( let inferred_ty = infer_expression_type(
db, db,
index.expression(assign.value(&module)), index.expression(assign.value(&module)),
true,
); );
union_of_inferred_types = union_of_inferred_types.add(inferred_ty); union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
@ -2995,6 +2996,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type( let iterable_ty = infer_expression_type(
db, db,
index.expression(for_stmt.iterable(&module)), index.expression(for_stmt.iterable(&module)),
true,
); );
// TODO: Potential diagnostics resulting from the iterable are currently not reported. // TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = let inferred_ty =
@ -3025,6 +3027,7 @@ impl<'db> ClassLiteral<'db> {
let context_ty = infer_expression_type( let context_ty = infer_expression_type(
db, db,
index.expression(with_item.context_expr(&module)), index.expression(with_item.context_expr(&module)),
true,
); );
let inferred_ty = if with_item.is_async() { let inferred_ty = if with_item.is_async() {
context_ty.aenter(db) context_ty.aenter(db)
@ -3058,6 +3061,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type( let iterable_ty = infer_expression_type(
db, db,
index.expression(comprehension.iterable(&module)), index.expression(comprehension.iterable(&module)),
true,
); );
// TODO: Potential diagnostics resulting from the iterable are currently not reported. // TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = let inferred_ty =

View File

@ -256,6 +256,7 @@ fn deferred_cycle_initial<'db>(
pub(crate) fn infer_expression_types<'db>( pub(crate) fn infer_expression_types<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
_break_cycle: bool,
) -> ExpressionInference<'db> { ) -> ExpressionInference<'db> {
let file = expression.file(db); let file = expression.file(db);
let module = parsed_module(db, file).load(db); let module = parsed_module(db, file).load(db);
@ -278,6 +279,7 @@ fn expression_cycle_recover<'db>(
_value: &ExpressionInference<'db>, _value: &ExpressionInference<'db>,
_count: u32, _count: u32,
_expression: Expression<'db>, _expression: Expression<'db>,
_break_cycle: bool,
) -> salsa::CycleRecoveryAction<ExpressionInference<'db>> { ) -> salsa::CycleRecoveryAction<ExpressionInference<'db>> {
salsa::CycleRecoveryAction::Iterate salsa::CycleRecoveryAction::Iterate
} }
@ -285,6 +287,7 @@ fn expression_cycle_recover<'db>(
fn expression_cycle_initial<'db>( fn expression_cycle_initial<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
_break_cycle: bool,
) -> ExpressionInference<'db> { ) -> ExpressionInference<'db> {
ExpressionInference::cycle_fallback(expression.scope(db)) ExpressionInference::cycle_fallback(expression.scope(db))
} }
@ -298,8 +301,9 @@ pub(super) fn infer_same_file_expression_type<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
parsed: &ParsedModuleRef, parsed: &ParsedModuleRef,
break_cycle: bool,
) -> Type<'db> { ) -> Type<'db> {
let inference = infer_expression_types(db, expression); let inference = infer_expression_types(db, expression, break_cycle);
inference.expression_type(expression.node_ref(db, parsed)) inference.expression_type(expression.node_ref(db, parsed))
} }
@ -314,19 +318,33 @@ pub(super) fn infer_same_file_expression_type<'db>(
pub(crate) fn infer_expression_type<'db>( pub(crate) fn infer_expression_type<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
break_cycle: bool,
) -> Type<'db> { ) -> Type<'db> {
let file = expression.file(db); let file = expression.file(db);
let module = parsed_module(db, file).load(db); let module = parsed_module(db, file).load(db);
// It's okay to call the "same file" version here because we're inside a salsa query. // It's okay to call the "same file" version here because we're inside a salsa query.
infer_same_file_expression_type(db, expression, &module) infer_same_file_expression_type(db, expression, &module, break_cycle)
} }
// #[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
// pub(crate) fn infer_expression_type_query<'db>(
// db: &'db dyn Db,
// expression: Expression<'db>,
// ) -> Type<'db> {
// let file = expression.file(db);
// let module = parsed_module(db, file).load(db);
// // It's okay to call the "same file" version here because we're inside a salsa query.
// infer_same_file_expression_type(db, expression, &module)
// }
fn single_expression_cycle_recover<'db>( fn single_expression_cycle_recover<'db>(
_db: &'db dyn Db, _db: &'db dyn Db,
_value: &Type<'db>, _value: &Type<'db>,
_count: u32, _count: u32,
_expression: Expression<'db>, _expression: Expression<'db>,
_break_cycle: bool,
) -> salsa::CycleRecoveryAction<Type<'db>> { ) -> salsa::CycleRecoveryAction<Type<'db>> {
salsa::CycleRecoveryAction::Iterate salsa::CycleRecoveryAction::Iterate
} }
@ -334,6 +352,7 @@ fn single_expression_cycle_recover<'db>(
fn single_expression_cycle_initial<'db>( fn single_expression_cycle_initial<'db>(
_db: &'db dyn Db, _db: &'db dyn Db,
_expression: Expression<'db>, _expression: Expression<'db>,
_break_cycle: bool,
) -> Type<'db> { ) -> Type<'db> {
Type::Never Type::Never
} }
@ -347,7 +366,7 @@ pub(crate) fn static_expression_truthiness<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
) -> Truthiness { ) -> Truthiness {
let inference = infer_expression_types(db, expression); let inference = infer_expression_types(db, expression, false);
if !inference.all_places_definitely_bound() { if !inference.all_places_definitely_bound() {
return Truthiness::Ambiguous; return Truthiness::Ambiguous;
@ -2676,7 +2695,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_definition(&mut self, node: impl Into<DefinitionNodeKey> + std::fmt::Debug + Copy) { fn infer_definition(&mut self, node: impl Into<DefinitionNodeKey> + std::fmt::Debug + Copy) {
let definition = self.index.expect_single_definition(node); let definition = self.index.expect_single_definition(node);
let result = infer_definition_types(self.db(), definition); let result = infer_definition_types(self.db(), definition);
self.extend_definition(result); self.extend_definition(&result);
} }
fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) { fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) {
@ -5142,7 +5161,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.check_deprecated(alias, ty.inner); self.check_deprecated(alias, ty.inner);
} }
} }
self.extend_definition(inferred); self.extend_definition(&inferred);
} }
} }
} }
@ -5635,7 +5654,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
expression: &ast::Expr, expression: &ast::Expr,
standalone_expression: Expression<'db>, standalone_expression: Expression<'db>,
) -> Type<'db> { ) -> Type<'db> {
let types = infer_expression_types(self.db(), standalone_expression); let types = infer_expression_types(self.db(), standalone_expression, false);
self.extend_expression(types); self.extend_expression(types);
// Instead of calling `self.expression_type(expr)` after extending here, we get // Instead of calling `self.expression_type(expr)` after extending here, we get
@ -6069,6 +6088,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
builder.db(), builder.db(),
builder.index.expression(iter_expr), builder.index.expression(iter_expr),
builder.module(), builder.module(),
false,
) )
} else { } else {
builder.infer_standalone_expression(iter_expr) builder.infer_standalone_expression(iter_expr)
@ -6091,7 +6111,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut infer_iterable_type = || { let mut infer_iterable_type = || {
let expression = self.index.expression(iterable); let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db(), expression); let result = infer_expression_types(self.db(), expression, false);
// Two things are different if it's the first comprehension: // Two things are different if it's the first comprehension:
// (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope, // (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope,
@ -6141,7 +6161,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if named.target.is_name_expr() { if named.target.is_name_expr() {
let definition = self.index.expect_single_definition(named); let definition = self.index.expect_single_definition(named);
let result = infer_definition_types(self.db(), definition); let result = infer_definition_types(self.db(), definition);
self.extend_definition(result); self.extend_definition(&result);
result.binding_type(definition) result.binding_type(definition)
} else { } else {
// For syntactically invalid targets, we still need to run type inference: // For syntactically invalid targets, we still need to run type inference:
@ -11667,7 +11687,7 @@ mod tests {
use ruff_db::diagnostic::Diagnostic; use ruff_db::diagnostic::Diagnostic;
use ruff_db::files::{File, system_path_to_file}; use ruff_db::files::{File, system_path_to_file};
use ruff_db::system::DbWithWritableSystem as _; use ruff_db::system::DbWithWritableSystem as _;
use ruff_db::testing::{assert_function_query_was_not_run, assert_function_query_was_run}; use ruff_db::testing::assert_function_query_was_not_run;
use super::*; use super::*;
@ -12049,267 +12069,267 @@ mod tests {
Ok(()) Ok(())
} }
#[test] // #[test]
fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { // fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
fn x_rhs_expression(db: &TestDb) -> Expression<'_> { // fn x_rhs_expression(db: &TestDb) -> Expression<'_> {
let file_main = system_path_to_file(db, "/src/main.py").unwrap(); // let file_main = system_path_to_file(db, "/src/main.py").unwrap();
let ast = parsed_module(db, file_main).load(db); // let ast = parsed_module(db, file_main).load(db);
// Get the second statement in `main.py` (x = …) and extract the expression // // Get the second statement in `main.py` (x = …) and extract the expression
// node on the right-hand side: // // node on the right-hand side:
let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value; // let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value;
let index = semantic_index(db, file_main); // let index = semantic_index(db, file_main);
index.expression(x_rhs_node.as_ref()) // index.expression(x_rhs_node.as_ref())
} // }
let mut db = setup_db(); // let mut db = setup_db();
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
def f(self): // def f(self):
self.attr: int | None = None // self.attr: int | None = None
"#, // "#,
)?; // )?;
db.write_dedented( // db.write_dedented(
"/src/main.py", // "/src/main.py",
r#" // r#"
from mod import C // from mod import C
x = C().attr // x = C().attr
"#, // "#,
)?; // )?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); // let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int | None"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int | None");
// Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred // // Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
def f(self): // def f(self):
self.attr: str | None = None // self.attr: str | None = None
"#, // "#,
)?; // )?;
let events = { // let events = {
db.clear_salsa_events(); // db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events() // db.take_salsa_events()
}; // };
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); // assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
// Add a comment; this should not trigger the type of `x` to be re-inferred // // Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
def f(self): // def f(self):
# a comment! // # a comment!
self.attr: str | None = None // self.attr: str | None = None
"#, // "#,
)?; // )?;
let events = { // let events = {
db.clear_salsa_events(); // db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events() // db.take_salsa_events()
}; // };
assert_function_query_was_not_run( // assert_function_query_was_not_run(
&db, // &db,
infer_expression_types, // infer_expression_types,
x_rhs_expression(&db), // x_rhs_expression(&db),
&events, // &events,
); // );
Ok(()) // Ok(())
} // }
/// This test verifies that changing a class's declaration in a non-meaningful way (e.g. by adding a comment) // /// This test verifies that changing a class's declaration in a non-meaningful way (e.g. by adding a comment)
/// doesn't trigger type inference for expressions that depend on the class's members. // /// doesn't trigger type inference for expressions that depend on the class's members.
#[test] // #[test]
fn dependency_own_instance_member() -> anyhow::Result<()> { // fn dependency_own_instance_member() -> anyhow::Result<()> {
fn x_rhs_expression(db: &TestDb) -> Expression<'_> { // fn x_rhs_expression(db: &TestDb) -> Expression<'_> {
let file_main = system_path_to_file(db, "/src/main.py").unwrap(); // let file_main = system_path_to_file(db, "/src/main.py").unwrap();
let ast = parsed_module(db, file_main).load(db); // let ast = parsed_module(db, file_main).load(db);
// Get the second statement in `main.py` (x = …) and extract the expression // // Get the second statement in `main.py` (x = …) and extract the expression
// node on the right-hand side: // // node on the right-hand side:
let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value; // let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value;
let index = semantic_index(db, file_main); // let index = semantic_index(db, file_main);
index.expression(x_rhs_node.as_ref()) // index.expression(x_rhs_node.as_ref())
} // }
let mut db = setup_db(); // let mut db = setup_db();
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
if random.choice([True, False]): // if random.choice([True, False]):
attr: int = 42 // attr: int = 42
else: // else:
attr: None = None // attr: None = None
"#, // "#,
)?; // )?;
db.write_dedented( // db.write_dedented(
"/src/main.py", // "/src/main.py",
r#" // r#"
from mod import C // from mod import C
x = C().attr // x = C().attr
"#, // "#,
)?; // )?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); // let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int | None"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int | None");
// Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred // // Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
if random.choice([True, False]): // if random.choice([True, False]):
attr: str = "42" // attr: str = "42"
else: // else:
attr: None = None // attr: None = None
"#, // "#,
)?; // )?;
let events = { // let events = {
db.clear_salsa_events(); // db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events() // db.take_salsa_events()
}; // };
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); // assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
// Add a comment; this should not trigger the type of `x` to be re-inferred // // Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
# comment // # comment
if random.choice([True, False]): // if random.choice([True, False]):
attr: str = "42" // attr: str = "42"
else: // else:
attr: None = None // attr: None = None
"#, // "#,
)?; // )?;
let events = { // let events = {
db.clear_salsa_events(); // db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events() // db.take_salsa_events()
}; // };
assert_function_query_was_not_run( // assert_function_query_was_not_run(
&db, // &db,
infer_expression_types, // infer_expression_types,
x_rhs_expression(&db), // x_rhs_expression(&db),
&events, // &events,
); // );
Ok(()) // Ok(())
} // }
#[test] // #[test]
fn dependency_implicit_class_member() -> anyhow::Result<()> { // fn dependency_implicit_class_member() -> anyhow::Result<()> {
fn x_rhs_expression(db: &TestDb) -> Expression<'_> { // fn x_rhs_expression(db: &TestDb) -> Expression<'_> {
let file_main = system_path_to_file(db, "/src/main.py").unwrap(); // let file_main = system_path_to_file(db, "/src/main.py").unwrap();
let ast = parsed_module(db, file_main).load(db); // let ast = parsed_module(db, file_main).load(db);
// Get the third statement in `main.py` (x = …) and extract the expression // // Get the third statement in `main.py` (x = …) and extract the expression
// node on the right-hand side: // // node on the right-hand side:
let x_rhs_node = &ast.syntax().body[2].as_assign_stmt().unwrap().value; // let x_rhs_node = &ast.syntax().body[2].as_assign_stmt().unwrap().value;
let index = semantic_index(db, file_main); // let index = semantic_index(db, file_main);
index.expression(x_rhs_node.as_ref()) // index.expression(x_rhs_node.as_ref())
} // }
let mut db = setup_db(); // let mut db = setup_db();
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
def __init__(self): // def __init__(self):
self.instance_attr: str = "24" // self.instance_attr: str = "24"
@classmethod // @classmethod
def method(cls): // def method(cls):
cls.class_attr: int = 42 // cls.class_attr: int = 42
"#, // "#,
)?; // )?;
db.write_dedented( // db.write_dedented(
"/src/main.py", // "/src/main.py",
r#" // r#"
from mod import C // from mod import C
C.method() // C.method()
x = C().class_attr // x = C().class_attr
"#, // "#,
)?; // )?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); // let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int");
// Change the type of `class_attr` to `str`; this should trigger the type of `x` to be re-inferred // // Change the type of `class_attr` to `str`; this should trigger the type of `x` to be re-inferred
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
def __init__(self): // def __init__(self):
self.instance_attr: str = "24" // self.instance_attr: str = "24"
@classmethod // @classmethod
def method(cls): // def method(cls):
cls.class_attr: str = "42" // cls.class_attr: str = "42"
"#, // "#,
)?; // )?;
let events = { // let events = {
db.clear_salsa_events(); // db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str");
db.take_salsa_events() // db.take_salsa_events()
}; // };
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); // assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
// Add a comment; this should not trigger the type of `x` to be re-inferred // // Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented( // db.write_dedented(
"/src/mod.py", // "/src/mod.py",
r#" // r#"
class C: // class C:
def __init__(self): // def __init__(self):
self.instance_attr: str = "24" // self.instance_attr: str = "24"
@classmethod // @classmethod
def method(cls): // def method(cls):
# comment // # comment
cls.class_attr: str = "42" // cls.class_attr: str = "42"
"#, // "#,
)?; // )?;
let events = { // let events = {
db.clear_salsa_events(); // db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").place.expect_type(); // let attr_ty = global_symbol(&db, file_main, "x").place.expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); // assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str");
db.take_salsa_events() // db.take_salsa_events()
}; // };
assert_function_query_was_not_run( // assert_function_query_was_not_run(
&db, // &db,
infer_expression_types, // infer_expression_types,
x_rhs_expression(&db), // x_rhs_expression(&db),
&events, // &events,
); // );
Ok(()) // Ok(())
} // }
} }

View File

@ -709,7 +709,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return None; return None;
} }
let inference = infer_expression_types(self.db, expression); let inference = infer_expression_types(self.db, expression, false);
let comparator_tuples = std::iter::once(&**left) let comparator_tuples = std::iter::once(&**left)
.chain(comparators) .chain(comparators)
@ -799,7 +799,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expression: Expression<'db>, expression: Expression<'db>,
is_positive: bool, is_positive: bool,
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression); let inference = infer_expression_types(self.db, expression, false);
let callable_ty = inference.expression_type(&*expr_call.func); let callable_ty = inference.expression_type(&*expr_call.func);
@ -921,7 +921,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let subject = place_expr(subject.node_ref(self.db, self.module))?; let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject); let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?; let ty = infer_same_file_expression_type(self.db, cls, self.module, false)
.to_instance(self.db)?;
Some(NarrowingConstraints::from_iter([(place, ty)])) Some(NarrowingConstraints::from_iter([(place, ty)]))
} }
@ -934,7 +935,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let subject = place_expr(subject.node_ref(self.db, self.module))?; let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject); let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, value, self.module); let ty = infer_same_file_expression_type(self.db, value, self.module, false);
Some(NarrowingConstraints::from_iter([(place, ty)])) Some(NarrowingConstraints::from_iter([(place, ty)]))
} }
@ -963,7 +964,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expression: Expression<'db>, expression: Expression<'db>,
is_positive: bool, is_positive: bool,
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression); let inference = infer_expression_types(self.db, expression, false);
let mut sub_constraints = expr_bool_op let mut sub_constraints = expr_bool_op
.values .values
.iter() .iter()

View File

@ -48,7 +48,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> {
"Unpacking target must be a list or tuple expression" "Unpacking target must be a list or tuple expression"
); );
let value_type = infer_expression_types(self.db(), value.expression()) let value_type = infer_expression_types(self.db(), value.expression(), false)
.expression_type(value.expression().node_ref(self.db(), self.module())); .expression_type(value.expression().node_ref(self.db(), self.module()));
let value_type = match value.kind() { let value_type = match value.kind() {