mirror of
https://github.com/astral-sh/ruff
synced 2026-01-21 05:20:49 -05:00
[ty] Handle optional errors in conformance workflow (#22647)
This commit is contained in:
@@ -40,7 +40,7 @@ import sys
|
||||
from dataclasses import dataclass
|
||||
from enum import Flag, StrEnum, auto
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from itertools import chain, groupby
|
||||
from operator import attrgetter, or_
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
@@ -49,10 +49,12 @@ from typing import Any, Literal, Self, assert_never
|
||||
# The conformance tests include 4 types of errors:
|
||||
# 1. Required errors (E): The type checker must raise an error on this line
|
||||
# 2. Optional errors (E?): The type checker may raise an error on this line
|
||||
# 3. Tagged errors (E[tag]): The type checker must raise at most one error on any of the lines in a file with matching tags
|
||||
# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or more errors on any of the tagged lines
|
||||
# This regex pattern parses the error lines in the conformance tests, but the following
|
||||
# implementation treats all errors as required errors.
|
||||
# 3. Tagged errors (E[tag]): The type checker must raise at most one error
|
||||
# on a set of lines with a matching tag
|
||||
# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or
|
||||
# more errors on a set of lines with a matching tag
|
||||
# This regex pattern parses the error lines in the conformance tests,
|
||||
# but the following implementation currently ignores error tags.
|
||||
CONFORMANCE_ERROR_PATTERN = re.compile(
|
||||
r"""
|
||||
\#\s*E # "# E" begins each error
|
||||
@@ -74,7 +76,7 @@ CONFORMANCE_SUITE_COMMIT = os.environ.get("CONFORMANCE_SUITE_COMMIT", "main")
|
||||
CONFORMANCE_DIR_WITH_README = (
|
||||
f"https://github.com/python/typing/blob/{CONFORMANCE_SUITE_COMMIT}/conformance/"
|
||||
)
|
||||
CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "/tests/{filename}#L{line}"
|
||||
CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "tests/{filename}#L{line}"
|
||||
|
||||
|
||||
class Source(Flag):
|
||||
@@ -101,6 +103,21 @@ class Classification(StrEnum):
|
||||
return "True positives removed"
|
||||
|
||||
|
||||
class Change(StrEnum):
|
||||
ADDED = auto()
|
||||
REMOVED = auto()
|
||||
UNCHANGED = auto()
|
||||
|
||||
def into_title(self) -> str:
|
||||
match self:
|
||||
case Change.ADDED:
|
||||
return "Optional Diagnostics Added"
|
||||
case Change.REMOVED:
|
||||
return "Optional Diagnostics Removed"
|
||||
case Change.UNCHANGED:
|
||||
return "Optional Diagnostics Unchanged"
|
||||
|
||||
|
||||
@dataclass(kw_only=True, slots=True)
|
||||
class Position:
|
||||
line: int
|
||||
@@ -136,8 +153,11 @@ class Diagnostic:
|
||||
fingerprint: str | None
|
||||
location: Location
|
||||
source: Source
|
||||
optional: bool
|
||||
|
||||
def __post_init__(self, *args, **kwargs) -> None:
|
||||
# Remove check name prefix from description
|
||||
self.description = self.description.replace(f"{self.check_name}: ", "")
|
||||
# Escape pipe characters for GitHub markdown tables
|
||||
self.description = self.description.replace("|", "\\|")
|
||||
|
||||
@@ -173,6 +193,7 @@ class Diagnostic:
|
||||
),
|
||||
),
|
||||
source=source,
|
||||
optional=False,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -196,12 +217,6 @@ class GroupedDiagnostics:
|
||||
new: Diagnostic | None
|
||||
expected: Diagnostic | None
|
||||
|
||||
@property
|
||||
def changed(self) -> bool:
|
||||
return (Source.OLD in self.sources or Source.NEW in self.sources) and not (
|
||||
Source.OLD in self.sources and Source.NEW in self.sources
|
||||
)
|
||||
|
||||
@property
|
||||
def classification(self) -> Classification:
|
||||
if Source.NEW in self.sources and Source.EXPECTED in self.sources:
|
||||
@@ -213,6 +228,19 @@ class GroupedDiagnostics:
|
||||
else:
|
||||
return Classification.TRUE_NEGATIVE
|
||||
|
||||
@property
|
||||
def change(self) -> Change:
|
||||
if Source.NEW in self.sources and Source.OLD not in self.sources:
|
||||
return Change.ADDED
|
||||
elif Source.OLD in self.sources and Source.NEW not in self.sources:
|
||||
return Change.REMOVED
|
||||
else:
|
||||
return Change.UNCHANGED
|
||||
|
||||
@property
|
||||
def optional(self) -> bool:
|
||||
return self.expected is not None and self.expected.optional
|
||||
|
||||
def _render_row(self, diagnostic: Diagnostic):
|
||||
return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |"
|
||||
|
||||
@@ -296,6 +324,7 @@ def collect_expected_diagnostics(path: Path) -> list[Diagnostic]:
|
||||
),
|
||||
),
|
||||
source=Source.EXPECTED,
|
||||
optional=error.group("optional") is not None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -331,6 +360,7 @@ def collect_ty_diagnostics(
|
||||
return [
|
||||
Diagnostic.from_gitlab_output(dct, source=source)
|
||||
for dct in json.loads(process.stdout)
|
||||
if dct["severity"] == "major"
|
||||
]
|
||||
|
||||
|
||||
@@ -342,6 +372,7 @@ def group_diagnostics_by_key(
|
||||
*new,
|
||||
*expected,
|
||||
]
|
||||
|
||||
sorted_diagnostics = sorted(diagnostics, key=attrgetter("key"))
|
||||
|
||||
grouped = []
|
||||
@@ -364,7 +395,8 @@ def group_diagnostics_by_key(
|
||||
|
||||
|
||||
def compute_stats(
|
||||
grouped_diagnostics: list[GroupedDiagnostics], source: Source
|
||||
grouped_diagnostics: list[GroupedDiagnostics],
|
||||
source: Source,
|
||||
) -> Statistics:
|
||||
if source == source.EXPECTED:
|
||||
# ty currently raises a false positive here due to incomplete enum.Flag support
|
||||
@@ -387,6 +419,8 @@ def compute_stats(
|
||||
statistics.false_negatives += 1
|
||||
return statistics
|
||||
|
||||
grouped_diagnostics = [diag for diag in grouped_diagnostics if not diag.optional]
|
||||
|
||||
return reduce(increment, grouped_diagnostics, Statistics())
|
||||
|
||||
|
||||
@@ -397,11 +431,21 @@ def render_grouped_diagnostics(
|
||||
format: Literal["diff", "github"] = "diff",
|
||||
) -> str:
|
||||
if changed_only:
|
||||
grouped = [diag for diag in grouped if diag.changed]
|
||||
grouped = [
|
||||
diag for diag in grouped if diag.change in (Change.ADDED, Change.REMOVED)
|
||||
]
|
||||
|
||||
sorted_by_class = sorted(
|
||||
grouped,
|
||||
key=attrgetter("classification"),
|
||||
get_change = attrgetter("change")
|
||||
get_classification = attrgetter("classification")
|
||||
|
||||
optional_diagnostics = sorted(
|
||||
(diag for diag in grouped if diag.optional),
|
||||
key=get_change,
|
||||
reverse=True,
|
||||
)
|
||||
required_diagnostics = sorted(
|
||||
(diag for diag in grouped if not diag.optional),
|
||||
key=get_classification,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
@@ -419,17 +463,16 @@ def render_grouped_diagnostics(
|
||||
raise ValueError("format must be one of 'diff' or 'github'")
|
||||
|
||||
lines = []
|
||||
for classification, group in groupby(
|
||||
sorted_by_class, key=attrgetter("classification")
|
||||
for group, diagnostics in chain(
|
||||
groupby(required_diagnostics, key=get_classification),
|
||||
groupby(optional_diagnostics, key=get_change),
|
||||
):
|
||||
group = list(group)
|
||||
|
||||
lines.append(f"### {classification.into_title()}")
|
||||
lines.append(f"### {group.into_title()}")
|
||||
lines.extend(["", "<details>", ""])
|
||||
|
||||
lines.extend(header)
|
||||
|
||||
for diag in group:
|
||||
for diag in diagnostics:
|
||||
lines.append(diag.display(format=format))
|
||||
|
||||
lines.append(footer)
|
||||
@@ -481,7 +524,7 @@ def render_summary(grouped_diagnostics: list[GroupedDiagnostics]):
|
||||
new = compute_stats(grouped_diagnostics, source=Source.NEW)
|
||||
|
||||
assert new.true_positives > 0, (
|
||||
"Expected ty to have at least one true positive "
|
||||
"Expected ty to have at least one true positive.\n"
|
||||
f"Sample of grouped diagnostics: {grouped_diagnostics[:5]}"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user