[ty] Handle optional errors in conformance workflow (#22647)

This commit is contained in:
Will Duke
2026-01-17 12:47:45 +00:00
committed by GitHub
parent ebf7d0cd2f
commit dfd236c028

View File

@@ -40,7 +40,7 @@ import sys
from dataclasses import dataclass
from enum import Flag, StrEnum, auto
from functools import reduce
from itertools import groupby
from itertools import chain, groupby
from operator import attrgetter, or_
from pathlib import Path
from textwrap import dedent
@@ -49,10 +49,12 @@ from typing import Any, Literal, Self, assert_never
# The conformance tests include 4 types of errors:
# 1. Required errors (E): The type checker must raise an error on this line
# 2. Optional errors (E?): The type checker may raise an error on this line
# 3. Tagged errors (E[tag]): The type checker must raise at most one error on any of the lines in a file with matching tags
# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or more errors on any of the tagged lines
# This regex pattern parses the error lines in the conformance tests, but the following
# implementation treats all errors as required errors.
# 3. Tagged errors (E[tag]): The type checker must raise exactly one error
# on a set of lines with a matching tag
# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or
# more errors on a set of lines with a matching tag
# This regex pattern parses the error lines in the conformance tests,
# but the following implementation currently ignores error tags.
CONFORMANCE_ERROR_PATTERN = re.compile(
r"""
\#\s*E # "# E" begins each error
@@ -74,7 +76,7 @@ CONFORMANCE_SUITE_COMMIT = os.environ.get("CONFORMANCE_SUITE_COMMIT", "main")
# Base URL of the conformance directory at the configured suite commit.
# Note that the value already ends with a trailing "/".
CONFORMANCE_DIR_WITH_README = (
f"https://github.com/python/typing/blob/{CONFORMANCE_SUITE_COMMIT}/conformance/"
)
# NOTE(review): this form produces a double slash ("conformance//tests/...")
# because the base above already ends with "/"; the assignment below
# rebinds the name without the leading slash.
CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "/tests/{filename}#L{line}"
# URL template pointing at a single line of a test file. This is a plain
# (non-f) string, so `{filename}` and `{line}` remain as placeholders here.
CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "tests/{filename}#L{line}"
class Source(Flag):
@@ -101,6 +103,21 @@ class Classification(StrEnum):
return "True positives removed"
# The kind of change a diagnostic can undergo: added, removed, or unchanged.
# As a StrEnum with auto() values, each member's value is its lower-cased
# name ("added", "removed", "unchanged"), so members compare and sort as
# strings.
class Change(StrEnum):
ADDED = auto()
REMOVED = auto()
UNCHANGED = auto()
# Return the human-readable title associated with this change kind.
# Every member is handled explicitly; there is no fallback case.
def into_title(self) -> str:
match self:
case Change.ADDED:
return "Optional Diagnostics Added"
case Change.REMOVED:
return "Optional Diagnostics Removed"
case Change.UNCHANGED:
return "Optional Diagnostics Unchanged"
@dataclass(kw_only=True, slots=True)
class Position:
line: int
@@ -136,8 +153,11 @@ class Diagnostic:
fingerprint: str | None
location: Location
source: Source
optional: bool
def __post_init__(self, *args, **kwargs) -> None:
# Remove check name prefix from description
self.description = self.description.replace(f"{self.check_name}: ", "")
# Escape pipe characters for GitHub markdown tables
self.description = self.description.replace("|", "\\|")
@@ -173,6 +193,7 @@ class Diagnostic:
),
),
source=source,
optional=False,
)
@property
@@ -196,12 +217,6 @@ class GroupedDiagnostics:
new: Diagnostic | None
expected: Diagnostic | None
# True when exactly one of Source.OLD and Source.NEW is present in
# `self.sources` (an exclusive-or): at least one of them is present, but
# not both at the same time.
@property
def changed(self) -> bool:
return (Source.OLD in self.sources or Source.NEW in self.sources) and not (
Source.OLD in self.sources and Source.NEW in self.sources
)
@property
def classification(self) -> Classification:
if Source.NEW in self.sources and Source.EXPECTED in self.sources:
@@ -213,6 +228,19 @@ class GroupedDiagnostics:
else:
return Classification.TRUE_NEGATIVE
# Classify this group by comparing the presence of Source.NEW and Source.OLD
# in `self.sources`:
#   - NEW without OLD -> Change.ADDED
#   - OLD without NEW -> Change.REMOVED
#   - both present, or neither present -> Change.UNCHANGED
@property
def change(self) -> Change:
if Source.NEW in self.sources and Source.OLD not in self.sources:
return Change.ADDED
elif Source.OLD in self.sources and Source.NEW not in self.sources:
return Change.REMOVED
else:
return Change.UNCHANGED
# True only when this group has an expected diagnostic and that diagnostic
# is flagged optional. A group with no expected diagnostic is never optional.
@property
def optional(self) -> bool:
return self.expected is not None and self.expected.optional
# Format one diagnostic as a pipe-delimited table row with three cells:
# the location link, the check name, and the description.
# NOTE(review): `as_link()` presumably returns a Markdown link — confirm
# against Location.
def _render_row(self, diagnostic: Diagnostic) -> str:
return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |"
@@ -296,6 +324,7 @@ def collect_expected_diagnostics(path: Path) -> list[Diagnostic]:
),
),
source=Source.EXPECTED,
optional=error.group("optional") is not None,
)
)
@@ -331,6 +360,7 @@ def collect_ty_diagnostics(
return [
Diagnostic.from_gitlab_output(dct, source=source)
for dct in json.loads(process.stdout)
if dct["severity"] == "major"
]
@@ -342,6 +372,7 @@ def group_diagnostics_by_key(
*new,
*expected,
]
sorted_diagnostics = sorted(diagnostics, key=attrgetter("key"))
grouped = []
@@ -364,7 +395,8 @@ def group_diagnostics_by_key(
def compute_stats(
grouped_diagnostics: list[GroupedDiagnostics], source: Source
grouped_diagnostics: list[GroupedDiagnostics],
source: Source,
) -> Statistics:
if source == source.EXPECTED:
# ty currently raises a false positive here due to incomplete enum.Flag support
@@ -387,6 +419,8 @@ def compute_stats(
statistics.false_negatives += 1
return statistics
grouped_diagnostics = [diag for diag in grouped_diagnostics if not diag.optional]
return reduce(increment, grouped_diagnostics, Statistics())
@@ -397,11 +431,21 @@ def render_grouped_diagnostics(
format: Literal["diff", "github"] = "diff",
) -> str:
if changed_only:
grouped = [diag for diag in grouped if diag.changed]
grouped = [
diag for diag in grouped if diag.change in (Change.ADDED, Change.REMOVED)
]
sorted_by_class = sorted(
grouped,
key=attrgetter("classification"),
get_change = attrgetter("change")
get_classification = attrgetter("classification")
optional_diagnostics = sorted(
(diag for diag in grouped if diag.optional),
key=get_change,
reverse=True,
)
required_diagnostics = sorted(
(diag for diag in grouped if not diag.optional),
key=get_classification,
reverse=True,
)
@@ -419,17 +463,16 @@ def render_grouped_diagnostics(
raise ValueError("format must be one of 'diff' or 'github'")
lines = []
for classification, group in groupby(
sorted_by_class, key=attrgetter("classification")
for group, diagnostics in chain(
groupby(required_diagnostics, key=get_classification),
groupby(optional_diagnostics, key=get_change),
):
group = list(group)
lines.append(f"### {classification.into_title()}")
lines.append(f"### {group.into_title()}")
lines.extend(["", "<details>", ""])
lines.extend(header)
for diag in group:
for diag in diagnostics:
lines.append(diag.display(format=format))
lines.append(footer)
@@ -481,7 +524,7 @@ def render_summary(grouped_diagnostics: list[GroupedDiagnostics]):
new = compute_stats(grouped_diagnostics, source=Source.NEW)
assert new.true_positives > 0, (
"Expected ty to have at least one true positive "
"Expected ty to have at least one true positive.\n"
f"Sample of grouped diagnostics: {grouped_diagnostics[:5]}"
)