diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 165a0a3b85..5869411fed 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -2021,8 +2021,8 @@ impl<'db> ClassLiteral<'db> { /// Returns a [`Span`] with the range of the class's header. /// /// See [`Self::header_range`] for more details. - pub(super) fn header_span(self, db: &'db dyn Db, module: &ParsedModuleRef) -> Span { - Span::from(self.file(db)).with_range(self.header_range(db, module)) + pub(super) fn header_span(self, db: &'db dyn Db) -> Span { + Span::from(self.file(db)).with_range(self.header_range(db)) } /// Returns the range of the class's "header": the class name @@ -2032,9 +2032,10 @@ impl<'db> ClassLiteral<'db> { /// class Foo(Bar, metaclass=Baz): ... /// ^^^^^^^^^^^^^^^^^^^^^^^ /// ``` - pub(super) fn header_range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> TextRange { + pub(super) fn header_range(self, db: &'db dyn Db) -> TextRange { let class_scope = self.body_scope(db); - let class_node = class_scope.node(db).expect_class(module); + let module = parsed_module(db.upcast(), class_scope.file(db)).load(db.upcast()); + let class_node = class_scope.node(db).expect_class(&module); let class_name = &class_node.name; TextRange::new( class_name.start(), diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index c613374c47..f0ecd96e46 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -1788,7 +1788,7 @@ pub(super) fn report_implicit_return_type( or `typing_extensions.Protocol` are considered protocol classes", ); sub_diagnostic.annotate( - Annotation::primary(class.header_span(db, context.module())).message(format_args!( + Annotation::primary(class.header_span(db)).message(format_args!( "`Protocol` not present in `{class}`'s immediate bases", class = class.name(db) )), @@ -1908,7 +1908,7 @@ pub(crate) fn report_bad_argument_to_get_protocol_members( class.name(db) ), ); - class_def_diagnostic.annotate(Annotation::primary(class.header_span(db, context.module()))); + class_def_diagnostic.annotate(Annotation::primary(class.header_span(db))); diagnostic.sub(class_def_diagnostic); diagnostic.info( @@ -1971,7 +1971,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( ), ); class_def_diagnostic.annotate( - Annotation::primary(protocol.header_span(db, context.module())) + Annotation::primary(protocol.header_span(db)) .message(format_args!("`{class_name}` declared here")), ); diagnostic.sub(class_def_diagnostic); @@ -2002,7 +2002,7 @@ pub(crate) fn report_attempted_protocol_instantiation( format_args!("Protocol classes cannot be instantiated"), ); class_def_diagnostic.annotate( - Annotation::primary(protocol.header_span(db, context.module())) + Annotation::primary(protocol.header_span(db)) .message(format_args!("`{class_name}` declared as a protocol here")), ); diagnostic.sub(class_def_diagnostic); @@ -2016,9 +2016,7 @@ pub(crate) fn report_duplicate_bases( ) { let db = context.db(); - let Some(builder) = - context.report_lint(&DUPLICATE_BASE, class.header_range(db, context.module())) - else { + let Some(builder) = context.report_lint(&DUPLICATE_BASE, class.header_range(db)) else { return; };