[ty] Fix classification of module in import x as y (#22175)

This commit is contained in:
Micha Reiser
2025-12-24 18:25:29 +01:00
committed by GitHub
parent eef403f6cf
commit ded4d4bbe9

View File

@@ -17,9 +17,6 @@
//! TODO: Need to properly handle Annotated expressions. All type arguments other
//! than the first should be treated as value expressions, not as type expressions.
//!
//! TODO: An identifier that resolves to a parameter when used within a function
//! should be classified as a parameter, selfParameter, or clsParameter token.
//!
//! TODO: Properties (or perhaps more generally, descriptor objects?) should be
//! classified as property tokens rather than just variables.
//!
@@ -230,6 +227,11 @@ impl<'db> SemanticTokenVisitor<'db> {
modifiers: SemanticTokenModifier,
) {
let range = ranged.range();
if range.is_empty() {
return;
}
// Only emit tokens that intersect with the range filter, if one is specified
if let Some(range_filter) = self.range_filter {
// Only include ranges that have a non-empty overlap. Adjacent ranges
@@ -707,15 +709,15 @@ impl SourceOrderVisitor<'_> for SemanticTokenVisitor<'_> {
}
ast::Stmt::Import(import) => {
for alias in &import.names {
// Create separate tokens for each part of a dotted module name
self.add_dotted_name_tokens(&alias.name, SemanticTokenType::Namespace);
if let Some(asname) = &alias.asname {
self.add_token(
asname.range(),
SemanticTokenType::Namespace,
SemanticTokenModifier::empty(),
);
} else {
// Create separate tokens for each part of a dotted module name
self.add_dotted_name_tokens(&alias.name, SemanticTokenType::Namespace);
}
}
}
@@ -1131,7 +1133,7 @@ mod tests {
use ty_project::ProjectMetadata;
#[test]
fn test_semantic_tokens_basic() {
fn semantic_tokens_basic() {
let test = SemanticTokenTest::new("def foo(): pass");
let tokens = test.highlight_file();
@@ -1142,7 +1144,7 @@ mod tests {
}
#[test]
fn test_semantic_tokens_class() {
fn semantic_tokens_class() {
let test = SemanticTokenTest::new("class MyClass: pass");
let tokens = test.highlight_file();
@@ -1153,7 +1155,7 @@ mod tests {
}
#[test]
fn test_semantic_tokens_class_args() {
fn semantic_tokens_class_args() {
// This used to cause a panic because of an incorrect
// insertion-order when visiting arguments inside
// class definitions.
@@ -1169,7 +1171,7 @@ mod tests {
}
#[test]
fn test_semantic_tokens_variables() {
fn semantic_tokens_variables() {
let test = SemanticTokenTest::new(
"
x = 42
@@ -1188,7 +1190,7 @@ y = 'hello'
}
#[test]
fn test_semantic_tokens_walrus() {
fn semantic_tokens_walrus() {
let test = SemanticTokenTest::new(
"
if x := 42:
@@ -1207,7 +1209,7 @@ if x := 42:
}
#[test]
fn test_semantic_tokens_self_parameter() {
fn semantic_tokens_self_parameter() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -1238,7 +1240,7 @@ class MyClass:
}
#[test]
fn test_semantic_tokens_cls_parameter() {
fn semantic_tokens_cls_parameter() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -1261,7 +1263,7 @@ class MyClass:
}
#[test]
fn test_semantic_tokens_staticmethod_parameter() {
fn semantic_tokens_staticmethod_parameter() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -1282,7 +1284,7 @@ class MyClass:
}
#[test]
fn test_semantic_tokens_custom_self_cls_names() {
fn semantic_tokens_custom_self_cls_names() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -1317,7 +1319,7 @@ class MyClass:
}
#[test]
fn test_semantic_tokens_modifiers() {
fn semantic_tokens_modifiers() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -1338,7 +1340,7 @@ class MyClass:
}
#[test]
fn test_semantic_classification_vs_heuristic() {
fn semantic_classification_vs_heuristic() {
let test = SemanticTokenTest::new(
"
import sys
@@ -1372,7 +1374,7 @@ z = sys.version
}
#[test]
fn test_builtin_constants() {
fn builtin_constants() {
let test = SemanticTokenTest::new(
"
x = True
@@ -1394,7 +1396,7 @@ z = None
}
#[test]
fn test_builtin_constants_in_expressions() {
fn builtin_constants_in_expressions() {
let test = SemanticTokenTest::new(
"
def check(value):
@@ -1422,7 +1424,7 @@ result = check(None)
}
#[test]
fn test_builtin_types() {
fn builtin_types() {
let test = SemanticTokenTest::new(
r#"
type U = str | int
@@ -1467,7 +1469,7 @@ result = check(None)
}
#[test]
fn test_semantic_tokens_range() {
fn semantic_tokens_range() {
let test = SemanticTokenTest::new(
"
def function1():
@@ -1532,7 +1534,7 @@ def function2():
/// When a token starts right at where the requested range ends,
/// don't include it in the semantic tokens.
#[test]
fn test_semantic_tokens_range_excludes_boundary_tokens() {
fn semantic_tokens_range_excludes_boundary_tokens() {
let test = SemanticTokenTest::new(
"
x = 1
@@ -1555,7 +1557,7 @@ z = 3
}
#[test]
fn test_dotted_module_names() {
fn dotted_module_names() {
let test = SemanticTokenTest::new(
"
import os.path
@@ -1582,7 +1584,7 @@ from collections.abc import Mapping
}
#[test]
fn test_module_type_classification() {
fn module_type_classification() {
let test = SemanticTokenTest::new(
"
import os
@@ -1610,7 +1612,7 @@ y = sys
}
#[test]
fn test_import_classification() {
fn import_classification() {
let test = SemanticTokenTest::new(
"
from os import path
@@ -1641,7 +1643,7 @@ from mymodule import CONSTANT, my_function, MyClass
}
#[test]
fn test_str_annotation() {
fn str_annotation() {
let test = SemanticTokenTest::new(
r#"
x: int = 1
@@ -1685,7 +1687,7 @@ w5: "float
}
#[test]
fn test_attribute_classification() {
fn attribute_classification() {
let test = SemanticTokenTest::new(
"
import os
@@ -1759,7 +1761,7 @@ u = List.__name__ # __name__ should be variable
}
#[test]
fn test_attribute_fallback_classification() {
fn attribute_fallback_classification() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -1790,7 +1792,7 @@ y = obj.unknown_attr # Should fall back to variable
}
#[test]
fn test_constant_name_detection() {
fn constant_name_detection() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -1837,7 +1839,7 @@ w = obj.A # Should not have readonly modifier (length == 1)
}
#[test]
fn test_type_annotations() {
fn type_annotations() {
let test = SemanticTokenTest::new(
r#"
from typing import List, Optional
@@ -2325,7 +2327,7 @@ class MyClass:
}
#[test]
fn test_debug_int_classification() {
fn debug_int_classification() {
let test = SemanticTokenTest::new(
"
x: int = 42
@@ -2342,7 +2344,7 @@ x: int = 42
}
#[test]
fn test_debug_user_defined_type_classification() {
fn debug_user_defined_type_classification() {
let test = SemanticTokenTest::new(
"
class MyClass:
@@ -2363,7 +2365,7 @@ x: MyClass = MyClass()
}
#[test]
fn test_type_annotation_vs_variable_classification() {
fn type_annotation_vs_variable_classification() {
let test = SemanticTokenTest::new(
"
from typing import List, Optional
@@ -2413,7 +2415,7 @@ def test_function(param: int, other: MyClass) -> Optional[List[str]]:
}
#[test]
fn test_protocol_types_in_annotations() {
fn protocol_types_in_annotations() {
let test = SemanticTokenTest::new(
"
from typing import Protocol
@@ -2444,7 +2446,7 @@ def test_function(param: MyProtocol) -> None:
}
#[test]
fn test_protocol_type_annotation_vs_value_context() {
fn protocol_type_annotation_vs_value_context() {
let test = SemanticTokenTest::new(
"
from typing import Protocol
@@ -2530,7 +2532,7 @@ def test_function(param: my_type_alias): ...
}
#[test]
fn test_type_parameters_pep695() {
fn type_parameters_pep695() {
let test = SemanticTokenTest::new(
"
# Test Python 3.12 PEP 695 type parameter syntax
@@ -2654,7 +2656,7 @@ class BoundedContainer[T: int, U = str]:
}
#[test]
fn test_type_parameters_usage_in_function_body() {
fn type_parameters_usage_in_function_body() {
let test = SemanticTokenTest::new(
"
def generic_function[T](value: T) -> T:
@@ -2683,7 +2685,7 @@ def generic_function[T](value: T) -> T:
}
#[test]
fn test_decorator_classification() {
fn decorator_classification() {
let test = SemanticTokenTest::new(
r#"
@staticmethod
@@ -2713,7 +2715,7 @@ class MyClass:
}
#[test]
fn test_constant_variations() {
fn constant_variations() {
let test = SemanticTokenTest::new(
r#"
A = 1
@@ -2756,7 +2758,7 @@ A_1 = 1
}
#[test]
fn test_implicitly_concatenated_strings() {
fn implicitly_concatenated_strings() {
let test = SemanticTokenTest::new(
r#"x = "hello" "world"
y = ("multi"
@@ -2783,7 +2785,7 @@ z = 'single' "mixed" 'quotes'"#,
}
#[test]
fn test_bytes_literals() {
fn bytes_literals() {
let test = SemanticTokenTest::new(
r#"x = b"hello" b"world"
y = (b"multi"
@@ -2810,7 +2812,7 @@ z = b'single' b"mixed" b'quotes'"#,
}
#[test]
fn test_mixed_string_and_bytes_literals() {
fn mixed_string_and_bytes_literals() {
let test = SemanticTokenTest::new(
r#"# Test mixed string and bytes literals
string_concat = "hello" "world"
@@ -2846,7 +2848,7 @@ regular_bytes = b"just bytes""#,
}
#[test]
fn test_fstring_with_mixed_literals() {
fn fstring_with_mixed_literals() {
let test = SemanticTokenTest::new(
r#"
# Test f-strings with various literal types
@@ -2898,7 +2900,7 @@ complex_fstring = f"User: {name.upper()}, Count: {len(data)}, Hex: {value:x}"
}
#[test]
fn test_nonlocal_and_global_statements() {
fn nonlocal_and_global_statements() {
let test = SemanticTokenTest::new(
r#"
x = "global_value"
@@ -2960,7 +2962,7 @@ def outer():
}
#[test]
fn test_nonlocal_global_edge_cases() {
fn nonlocal_global_edge_cases() {
let test = SemanticTokenTest::new(
r#"
# Single variable statements
@@ -3000,7 +3002,7 @@ def test():
}
#[test]
fn test_pattern_matching() {
fn pattern_matching() {
let test = SemanticTokenTest::new(
r#"
def process_data(data):
@@ -3056,7 +3058,7 @@ def process_data(data):
}
#[test]
fn test_exception_handlers() {
fn exception_handlers() {
let test = SemanticTokenTest::new(
r#"
try:
@@ -3095,7 +3097,7 @@ finally:
}
#[test]
fn test_self_attribute_expression() {
fn self_attribute_expression() {
let test = SemanticTokenTest::new(
r#"
from typing import Self
@@ -3135,7 +3137,7 @@ class C:
}
#[test]
fn test_augmented_assignment() {
fn augmented_assignment() {
let test = SemanticTokenTest::new(
r#"
x = 0
@@ -3154,7 +3156,7 @@ x += 1
}
#[test]
fn test_type_alias() {
fn type_alias() {
let test = SemanticTokenTest::new("type MyList[T] = list[T]");
let tokens = test.highlight_file();
@@ -3168,7 +3170,7 @@ x += 1
}
#[test]
fn test_for_stmt() {
fn for_stmt() {
let test = SemanticTokenTest::new(
r#"
for item in []:
@@ -3190,7 +3192,7 @@ else:
}
#[test]
fn test_with_stmt() {
fn with_stmt() {
let test = SemanticTokenTest::new(
r#"
with open("file.txt") as f:
@@ -3210,7 +3212,7 @@ with open("file.txt") as f:
}
#[test]
fn test_comprehensions() {
fn comprehensions() {
let test = SemanticTokenTest::new(
r#"
list_comp = [x for x in range(10) if x % 2 == 0]
@@ -3255,7 +3257,7 @@ generator = (x for x in range(10))
/// Regression test for <https://github.com/astral-sh/ty/issues/1406>
#[test]
fn test_invalid_kwargs() {
fn invalid_kwargs() {
let test = SemanticTokenTest::new(
r#"
def foo(self, **key, value=10):
@@ -3274,6 +3276,24 @@ def foo(self, **key, value=10):
"#);
}
#[test]
fn import_as() {
let test = SemanticTokenTest::new(
r#"
import pathlib as path
from pathlib import Path
"#,
);
let tokens = test.highlight_file();
assert_snapshot!(test.to_snapshot(&tokens), @r#"
"pathlib" @ 8..15: Namespace
"path" @ 19..23: Namespace
"pathlib" @ 29..36: Namespace
"Path" @ 44..48: Class
"#);
}
pub(super) struct SemanticTokenTest {
pub(super) db: ty_project::TestDb,
file: File,