From dfd236c02880bee98571559507d7890774fc053e Mon Sep 17 00:00:00 2001 From: Will Duke <41601410+WillDuke@users.noreply.github.com> Date: Sat, 17 Jan 2026 12:47:45 +0000 Subject: [PATCH] [ty] Handle optional errors in conformance workflow (#22647) --- scripts/conformance.py | 91 +++++++++++++++++++++++++++++++----------- 1 file changed, 67 insertions(+), 24 deletions(-) diff --git a/scripts/conformance.py b/scripts/conformance.py index eda4368d2c..2b45a6da6e 100644 --- a/scripts/conformance.py +++ b/scripts/conformance.py @@ -40,7 +40,7 @@ import sys from dataclasses import dataclass from enum import Flag, StrEnum, auto from functools import reduce -from itertools import groupby +from itertools import chain, groupby from operator import attrgetter, or_ from pathlib import Path from textwrap import dedent @@ -49,10 +49,12 @@ from typing import Any, Literal, Self, assert_never # The conformance tests include 4 types of errors: # 1. Required errors (E): The type checker must raise an error on this line # 2. Optional errors (E?): The type checker may raise an error on this line -# 3. Tagged errors (E[tag]): The type checker must raise at most one error on any of the lines in a file with matching tags -# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or more errors on any of the tagged lines -# This regex pattern parses the error lines in the conformance tests, but the following -# implementation treats all errors as required errors. +# 3. Tagged errors (E[tag]): The type checker must raise at most one error +# on a set of lines with a matching tag +# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or +# more errors on a set of lines with a matching tag +# This regex pattern parses the error lines in the conformance tests, +# but the following implementation currently ignores error tags. CONFORMANCE_ERROR_PATTERN = re.compile( r""" \#\s*E # "# E" begins each error @@ -74,7 +76,7 @@ CONFORMANCE_SUITE_COMMIT = os.environ.get("CONFORMANCE_SUITE_COMMIT", "main") CONFORMANCE_DIR_WITH_README = ( f"https://github.com/python/typing/blob/{CONFORMANCE_SUITE_COMMIT}/conformance/" ) -CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "/tests/{filename}#L{line}" +CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "tests/{filename}#L{line}" class Source(Flag): @@ -101,6 +103,21 @@ class Classification(StrEnum): return "True positives removed" +class Change(StrEnum): + ADDED = auto() + REMOVED = auto() + UNCHANGED = auto() + + def into_title(self) -> str: + match self: + case Change.ADDED: + return "Optional Diagnostics Added" + case Change.REMOVED: + return "Optional Diagnostics Removed" + case Change.UNCHANGED: + return "Optional Diagnostics Unchanged" + + @dataclass(kw_only=True, slots=True) class Position: line: int @@ -136,8 +153,11 @@ class Diagnostic: fingerprint: str | None location: Location source: Source + optional: bool def __post_init__(self, *args, **kwargs) -> None: + # Remove check name prefix from description + self.description = self.description.replace(f"{self.check_name}: ", "") # Escape pipe characters for GitHub markdown tables self.description = self.description.replace("|", "\\|") @@ -173,6 +193,7 @@ class Diagnostic: ), ), source=source, + optional=False, ) @property @@ -196,12 +217,6 @@ class GroupedDiagnostics: new: Diagnostic | None expected: Diagnostic | None - @property - def changed(self) -> bool: - return (Source.OLD in self.sources or Source.NEW in self.sources) and not ( - Source.OLD in self.sources and Source.NEW in self.sources - ) - @property def classification(self) -> Classification: if Source.NEW in self.sources and Source.EXPECTED in self.sources: @@ -213,6 +228,19 @@ class GroupedDiagnostics: else: return Classification.TRUE_NEGATIVE + @property + def change(self) -> Change: + if Source.NEW in self.sources and Source.OLD not in self.sources: + return Change.ADDED + elif Source.OLD in self.sources and Source.NEW not in self.sources: + return Change.REMOVED + else: + return Change.UNCHANGED + + @property + def optional(self) -> bool: + return self.expected is not None and self.expected.optional + def _render_row(self, diagnostic: Diagnostic): return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |" @@ -296,6 +324,7 @@ def collect_expected_diagnostics(path: Path) -> list[Diagnostic]: ), ), source=Source.EXPECTED, + optional=error.group("optional") is not None, ) ) @@ -331,6 +360,7 @@ def collect_ty_diagnostics( return [ Diagnostic.from_gitlab_output(dct, source=source) for dct in json.loads(process.stdout) + if dct["severity"] == "major" ] @@ -342,6 +372,7 @@ def group_diagnostics_by_key( *new, *expected, ] + sorted_diagnostics = sorted(diagnostics, key=attrgetter("key")) grouped = [] @@ -364,7 +395,8 @@ def group_diagnostics_by_key( def compute_stats( - grouped_diagnostics: list[GroupedDiagnostics], source: Source + grouped_diagnostics: list[GroupedDiagnostics], + source: Source, ) -> Statistics: if source == source.EXPECTED: # ty currently raises a false positive here due to incomplete enum.Flag support @@ -387,6 +419,8 @@ def compute_stats( statistics.false_negatives += 1 return statistics + grouped_diagnostics = [diag for diag in grouped_diagnostics if not diag.optional] + return reduce(increment, grouped_diagnostics, Statistics()) @@ -397,11 +431,21 @@ def render_grouped_diagnostics( format: Literal["diff", "github"] = "diff", ) -> str: if changed_only: - grouped = [diag for diag in grouped if diag.changed] + grouped = [ + diag for diag in grouped if diag.change in (Change.ADDED, Change.REMOVED) + ] - sorted_by_class = sorted( - grouped, - key=attrgetter("classification"), + get_change = attrgetter("change") + get_classification = attrgetter("classification") + + optional_diagnostics = sorted( + (diag for diag in grouped if diag.optional), + key=get_change, + reverse=True, + ) + required_diagnostics = sorted( + (diag for diag in grouped if not diag.optional), + key=get_classification, reverse=True, ) @@ -419,17 +463,16 @@ def render_grouped_diagnostics( raise ValueError("format must be one of 'diff' or 'github'") lines = [] - for classification, group in groupby( - sorted_by_class, key=attrgetter("classification") + for group, diagnostics in chain( + groupby(required_diagnostics, key=get_classification), + groupby(optional_diagnostics, key=get_change), ): - group = list(group) - - lines.append(f"### {classification.into_title()}") + lines.append(f"### {group.into_title()}") lines.extend(["", "
", ""]) lines.extend(header) - for diag in group: + for diag in diagnostics: lines.append(diag.display(format=format)) lines.append(footer) @@ -481,7 +524,7 @@ def render_summary(grouped_diagnostics: list[GroupedDiagnostics]): new = compute_stats(grouped_diagnostics, source=Source.NEW) assert new.true_positives > 0, ( - "Expected ty to have at least one true positive " + "Expected ty to have at least one true positive.\n" f"Sample of grouped diagnostics: {grouped_diagnostics[:5]}" )