[ty] Add a conformance script to compare ty diagnostics with expected errors (#22231)

This commit is contained in:
Will Duke
2026-01-13 22:19:41 +00:00
committed by GitHub
parent 9a676bbeb7
commit 4ebf10cf1b

607
scripts/conformance.py Normal file
View File

@@ -0,0 +1,607 @@
"""
Run typing conformance tests and compare results between two ty versions.
By default, this script will use `uv` to run the latest version of ty
as the new version with `uvx ty@latest`. This requires `uv` to be installed
and available in the system PATH.
If CONFORMANCE_SUITE_COMMIT is set, the hash will be used to create
links to the corresponding line in the conformance repository for each
diagnostic. Otherwise, it will default to `main'.
Examples:
# Compare two specific ty versions
%(prog)s --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7
# Use local ty builds
%(prog)s --old-ty ./target/debug/ty-old --new-ty ./target/debug/ty-new
# Custom test directory
%(prog)s --target-path custom/tests --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7
# Show all diagnostics (not just changed ones)
%(prog)s --all --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7
# Show a diff with local paths to the test directory instead of table of links
%(prog)s --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7 --format diff
"""
from __future__ import annotations
import argparse
import json
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from enum import Flag, StrEnum, auto
from functools import reduce
from itertools import groupby
from operator import attrgetter, or_
from pathlib import Path
from textwrap import dedent
from typing import Any, Literal, Self
# The conformance tests include 4 types of errors:
# 1. Required errors (E): The type checker must raise an error on this line
# 2. Optional errors (E?): The type checker may raise an error on this line
# 3. Tagged errors (E[tag]): The type checker must raise at most one error on any of the lines in a file with matching tags
# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or more errors on any of the tagged lines
# This regex pattern parses the error lines in the conformance tests, but the following
# implementation treats all errors as required errors.
CONFORMANCE_ERROR_PATTERN = re.compile(
r"""
\#\s*E # "# E" begins each error
(?P<optional>\?)? # Optional '?' (E?) indicates that an error is optional
(?: # An optional tag for errors that may appear on multiple lines at most once
\[
(?P<tag>[^+\]]+) # identifier
(?P<multi>\+)? # '+' indicates that an error may occur more than once on tagged lines
\]
)?
(?:
\s*:\s*(?P<description>.*) # optional description
)?
""",
re.VERBOSE,
)
CONFORMANCE_URL = "https://github.com/python/typing/blob/{conformance_suite_commit}/conformance/tests/{filename}#L{line}"
CONFORMANCE_SUITE_COMMIT = os.environ.get("CONFORMANCE_SUITE_COMMIT", "main")
class Source(Flag):
OLD = auto()
NEW = auto()
EXPECTED = auto()
class Classification(StrEnum):
TRUE_POSITIVE = auto()
FALSE_POSITIVE = auto()
TRUE_NEGATIVE = auto()
FALSE_NEGATIVE = auto()
def into_title(self) -> str:
match self:
case Classification.TRUE_POSITIVE:
return "True positives added 🎉"
case Classification.FALSE_POSITIVE:
return "False positives added 🫤"
case Classification.TRUE_NEGATIVE:
return "False positives removed 🎉"
case Classification.FALSE_NEGATIVE:
return "True positives removed 🫤"
@dataclass(kw_only=True, slots=True)
class Position:
line: int
column: int
@dataclass(kw_only=True, slots=True)
class Positions:
begin: Position
end: Position
@dataclass(kw_only=True, slots=True)
class Location:
path: str
positions: Positions
def as_link(self) -> str:
file = os.path.basename(self.path)
link = CONFORMANCE_URL.format(
conformance_suite_commit=CONFORMANCE_SUITE_COMMIT,
filename=file,
line=self.positions.begin.line,
)
return f"[{file}:{self.positions.begin.line}:{self.positions.begin.column}]({link})"
@dataclass(kw_only=True, slots=True)
class Diagnostic:
check_name: str
description: str
severity: str
fingerprint: str | None
location: Location
source: Source
def __str__(self) -> str:
return (
f"{self.location.path}:{self.location.positions.begin.line}:"
f"{self.location.positions.begin.column}: "
f"{self.severity_for_display}[{self.check_name}] {self.description}"
)
@classmethod
def from_gitlab_output(
cls,
dct: dict[str, Any],
source: Source,
) -> Self:
return cls(
check_name=dct["check_name"],
description=dct["description"],
severity=dct["severity"],
fingerprint=dct["fingerprint"],
location=Location(
path=dct["location"]["path"],
positions=Positions(
begin=Position(
line=dct["location"]["positions"]["begin"]["line"],
column=dct["location"]["positions"]["begin"]["column"],
),
end=Position(
line=dct["location"]["positions"]["end"]["line"],
column=dct["location"]["positions"]["end"]["column"],
),
),
),
source=source,
)
@property
def key(self) -> str:
"""Key to group diagnostics by path and beginning line."""
return f"{self.location.path}:{self.location.positions.begin.line}"
@property
def severity_for_display(self) -> str:
return {
"major": "error",
"minor": "warning",
}.get(self.severity, "unknown")
@dataclass(kw_only=True, slots=True)
class GroupedDiagnostics:
key: str
sources: Source
old: Diagnostic | None
new: Diagnostic | None
expected: Diagnostic | None
@property
def changed(self) -> bool:
return (Source.OLD in self.sources or Source.NEW in self.sources) and not (
Source.OLD in self.sources and Source.NEW in self.sources
)
@property
def classification(self) -> Classification:
if Source.NEW in self.sources and Source.EXPECTED in self.sources:
return Classification.TRUE_POSITIVE
elif Source.NEW in self.sources and Source.EXPECTED not in self.sources:
return Classification.FALSE_POSITIVE
elif Source.EXPECTED in self.sources:
return Classification.FALSE_NEGATIVE
else:
return Classification.TRUE_NEGATIVE
def _render_row(self, diagnostic: Diagnostic):
return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |"
def _render_diff(self, diagnostic: Diagnostic, *, removed: bool = False):
sign = "-" if removed else "+"
return f"{sign} {diagnostic}"
def display(self, format: Literal["diff", "github"]) -> str:
match self.classification:
case Classification.TRUE_POSITIVE | Classification.FALSE_POSITIVE:
assert self.new is not None
return (
self._render_diff(self.new)
if format == "diff"
else self._render_row(self.new)
)
case Classification.FALSE_NEGATIVE | Classification.TRUE_NEGATIVE:
diagnostic = self.old or self.expected
assert diagnostic is not None
return (
self._render_diff(diagnostic, removed=True)
if format == "diff"
else self._render_row(diagnostic)
)
case _:
raise ValueError(f"Unexpected classification: {self.classification}")
@dataclass(kw_only=True, slots=True)
class Statistics:
true_positives: int = 0
false_positives: int = 0
false_negatives: int = 0
@property
def precision(self) -> float:
if self.true_positives + self.false_positives > 0:
return self.true_positives / (self.true_positives + self.false_positives)
return 0.0
@property
def recall(self) -> float:
if self.true_positives + self.false_negatives > 0:
return self.true_positives / (self.true_positives + self.false_negatives)
else:
return 0.0
@property
def total(self) -> int:
return self.true_positives + self.false_positives
def collect_expected_diagnostics(path: Path) -> list[Diagnostic]:
diagnostics: list[Diagnostic] = []
for file in path.resolve().rglob("*.py"):
for idx, line in enumerate(file.read_text().splitlines(), 1):
if error := re.search(CONFORMANCE_ERROR_PATTERN, line):
diagnostics.append(
Diagnostic(
check_name="conformance",
description=error.group("description")
or error.group("tag")
or "Missing",
severity="major",
fingerprint=None,
location=Location(
path=file.as_posix(),
positions=Positions(
begin=Position(
line=idx,
column=error.start(),
),
end=Position(
line=idx,
column=error.end(),
),
),
),
source=Source.EXPECTED,
)
)
assert diagnostics, "Failed to discover any expected diagnostics!"
return diagnostics
def collect_ty_diagnostics(
ty_path: list[str],
source: Source,
tests_path: str = ".",
python_version: str = "3.12",
) -> list[Diagnostic]:
process = subprocess.run(
[
*ty_path,
"check",
f"--python-version={python_version}",
"--output-format=gitlab",
"--exit-zero",
tests_path,
],
capture_output=True,
text=True,
check=True,
timeout=15,
)
if process.returncode != 0:
print(process.stderr)
raise RuntimeError(f"ty check failed with exit code {process.returncode}")
return [
Diagnostic.from_gitlab_output(dct, source=source)
for dct in json.loads(process.stdout)
]
def group_diagnostics_by_key(
old: list[Diagnostic], new: list[Diagnostic], expected: list[Diagnostic]
) -> list[GroupedDiagnostics]:
diagnostics = [
*old,
*new,
*expected,
]
sorted_diagnostics = sorted(diagnostics, key=attrgetter("key"))
grouped = []
for key, group in groupby(sorted_diagnostics, key=attrgetter("key")):
group = list(group)
sources: Source = reduce(or_, (diag.source for diag in group))
grouped.append(
GroupedDiagnostics(
key=key,
sources=sources,
old=next(filter(lambda diag: diag.source == Source.OLD, group), None),
new=next(filter(lambda diag: diag.source == Source.NEW, group), None),
expected=next(
filter(lambda diag: diag.source == Source.EXPECTED, group), None
),
)
)
return grouped
def compute_stats(
grouped_diagnostics: list[GroupedDiagnostics], source: Source
) -> Statistics:
if source == source.EXPECTED:
# ty currently raises a false positive here due to incomplete enum.Flag support
# see https://github.com/astral-sh/ty/issues/876
num_errors = sum(
1
for g in grouped_diagnostics
if source.EXPECTED in g.sources # ty:ignore[unsupported-operator]
)
return Statistics(
true_positives=num_errors, false_positives=0, false_negatives=0
)
def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics:
if (source in grouped.sources) and (Source.EXPECTED in grouped.sources):
statistics.true_positives += 1
elif source in grouped.sources:
statistics.false_positives += 1
elif Source.EXPECTED in grouped.sources:
statistics.false_negatives += 1
return statistics
return reduce(increment, grouped_diagnostics, Statistics())
def render_grouped_diagnostics(
grouped: list[GroupedDiagnostics],
*,
changed_only: bool = True,
format: Literal["diff", "github"] = "diff",
) -> str:
if changed_only:
grouped = [diag for diag in grouped if diag.changed]
sorted_by_class = sorted(
grouped,
key=attrgetter("classification"),
reverse=True,
)
match format:
case "diff":
header = ["```diff"]
footer = "```"
case "github":
header = [
"| Location | Name | Message |",
"|----------|------|---------|",
]
footer = ""
case _:
raise ValueError("format must be one of 'diff' or 'github'")
lines = []
for classification, group in groupby(
sorted_by_class, key=attrgetter("classification")
):
group = list(group)
lines.append(f"## {classification.into_title()}")
lines.extend(["", "<details>", ""])
lines.extend(header)
for diag in group:
lines.append(diag.display(format=format))
lines.append(footer)
lines.extend(["", "</details>", ""])
return "\n".join(lines)
def diff_format(
diff: float,
*,
greater_is_better: bool = True,
neutral: bool = False,
is_percentage: bool = False,
):
increased = diff > 0
good = " (✅)" if not neutral else ""
bad = " (❌)" if not neutral else ""
up = ""
down = ""
match (greater_is_better, increased):
case (True, True):
return f"{up}{good}"
case (False, True):
return f"{up}{bad}"
case (True, False):
return f"{down}{bad}"
case (False, False):
return f"{down}{good}"
def render_summary(grouped_diagnostics: list[GroupedDiagnostics]):
def format_metric(diff: float, old: float, new: float):
if diff > 0:
return f"increased from {old:.2%} to {new:.2%}"
if diff < 0:
return f"decreased from {old:.2%} to {new:.2%}"
return f"held steady at {old:.2%}"
old = compute_stats(grouped_diagnostics, source=Source.OLD)
new = compute_stats(grouped_diagnostics, source=Source.NEW)
precision_change = new.precision - old.precision
recall_change = new.recall - old.recall
true_pos_change = new.true_positives - old.true_positives
false_pos_change = new.false_positives - old.false_positives
false_neg_change = new.false_negatives - old.false_negatives
total_change = new.total - old.total
true_pos_diff = diff_format(true_pos_change, greater_is_better=True)
false_pos_diff = diff_format(false_pos_change, greater_is_better=False)
false_neg_diff = diff_format(false_neg_change, greater_is_better=False)
precision_diff = diff_format(
precision_change, greater_is_better=True, is_percentage=True
)
recall_diff = diff_format(recall_change, greater_is_better=True, is_percentage=True)
total_diff = diff_format(total_change, neutral=True)
table = dedent(
f"""
## Typing Conformance
### Summary
| Metric | Old | New | Diff | Outcome |
|--------|-----|-----|------|---------|
| True Positives | {old.true_positives} | {new.true_positives} | {true_pos_change:+} | {true_pos_diff} |
| False Positives | {old.false_positives} | {new.false_positives} | {false_pos_change:+} | {false_pos_diff} |
| False Negatives | {old.false_negatives} | {new.false_negatives} | {false_neg_change:+} | {false_neg_diff} |
| Total Diagnostics | {old.total} | {new.total} | {total_change} | {total_diff} |
| Precision | {old.precision:.2%} | {new.precision:.2%} | {precision_change:+.2%} | {precision_diff} |
| Recall | {old.recall:.2%} | {new.recall:.2%} | {recall_change:+.2%} | {recall_diff} |
"""
)
summary = (
f"The percentage of diagnostics emitted that were expected errors"
f" {format_metric(precision_change, old.precision, new.precision)},"
" and the percentage of expected errors that received a diagnostic"
f" {format_metric(recall_change, old.recall, new.recall)}."
)
return "\n".join([table, summary])
def parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--old-ty",
nargs="+",
help="Command to run old version of ty (default: uvx ty@0.0.1a35)",
)
parser.add_argument(
"--new-ty",
nargs="+",
default=["uvx", "ty@latest"],
help="Command to run new version of ty (default: uvx ty@0.0.7)",
)
parser.add_argument(
"--tests-path",
type=Path,
default=Path("typing/conformance/tests"),
help="Path to conformance tests directory (default: typing/conformance/tests)",
)
parser.add_argument(
"--python-version",
type=str,
default="3.12",
help="Python version to assume when running ty (default: 3.12)",
)
parser.add_argument(
"--all",
action="store_true",
help="Show all diagnostics, not just changed ones",
)
parser.add_argument(
"--format", type=str, choices=["diff", "github"], default="github"
)
parser.add_argument(
"--output",
type=Path,
help="Write output to file instead of stdout",
)
args = parser.parse_args()
if args.old_ty is None:
raise ValueError("old_ty is required")
return args
def main():
args = parse_args()
expected = collect_expected_diagnostics(args.tests_path)
old = collect_ty_diagnostics(
ty_path=args.old_ty,
tests_path=str(args.tests_path),
source=Source.OLD,
python_version=args.python_version,
)
new = collect_ty_diagnostics(
ty_path=args.new_ty,
tests_path=str(args.tests_path),
source=Source.NEW,
python_version=args.python_version,
)
grouped = group_diagnostics_by_key(
old=old,
new=new,
expected=expected,
)
rendered = "\n\n".join(
[
render_summary(grouped),
render_grouped_diagnostics(
grouped, changed_only=not args.all, format=args.format
),
]
)
if args.output:
args.output.write_text(rendered, encoding="utf-8")
print(f"Output written to {args.output}", file=sys.stderr)
else:
print(rendered)
if __name__ == "__main__":
main()