"""
|
|
Run typing conformance tests and compare results between two ty versions.
|
|
|
|
By default, this script will use `uv` to run the latest version of ty
|
|
as the new version with `uvx ty@latest`. This requires `uv` to be installed
|
|
and available in the system PATH.
|
|
|
|
If CONFORMANCE_SUITE_COMMIT is set, the hash will be used to create
|
|
links to the corresponding line in the conformance repository for each
|
|
diagnostic. Otherwise, it will default to `main'.
|
|
|
|
Examples:
|
|
# Compare an older version of ty to latest
|
|
%(prog)s --old-ty uvx ty@0.0.1a35
|
|
|
|
# Compare two specific ty versions
|
|
%(prog)s --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7
|
|
|
|
# Use local ty builds
|
|
%(prog)s --old-ty ./target/debug/ty-old --new-ty ./target/debug/ty-new
|
|
|
|
# Custom test directory
|
|
%(prog)s --target-path custom/tests --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7
|
|
|
|
# Show all diagnostics (not just changed ones)
|
|
%(prog)s --all --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7
|
|
|
|
# Show a diff with local paths to the test directory instead of table of links
|
|
%(prog)s --old-ty uvx ty@0.0.1a35 --new-ty uvx ty@0.0.7 --format diff
|
|
"""

from __future__ import annotations

import argparse
import json
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from enum import Flag, StrEnum, auto
from functools import reduce
from itertools import groupby
from operator import attrgetter, or_
from pathlib import Path
from textwrap import dedent
from typing import Any, Literal, Self, assert_never

# The conformance tests include 4 types of errors:
# 1. Required errors (E): The type checker must raise an error on this line
# 2. Optional errors (E?): The type checker may raise an error on this line
# 3. Tagged errors (E[tag]): The type checker must raise at most one error on any of the lines in a file with matching tags
# 4. Tagged multi-errors (E[tag+]): The type checker should raise one or more errors on any of the tagged lines
# This regex pattern parses the error lines in the conformance tests, but the following
# implementation treats all errors as required errors.
CONFORMANCE_ERROR_PATTERN = re.compile(
    r"""
    \#\s*E                         # "# E" begins each error
    (?P<optional>\?)?              # Optional '?' (E?) indicates that an error is optional
    (?:                            # An optional tag for errors that may appear on multiple lines at most once
        \[
        (?P<tag>[^+\]]+)           # identifier
        (?P<multi>\+)?             # '+' indicates that an error may occur more than once on tagged lines
        \]
    )?
    (?:
        \s*:\s*(?P<description>.*) # optional description
    )?
    """,
    re.VERBOSE,
)
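
# Illustrative examples (hypothetical lines, not taken from the suite) of
# comments the pattern above recognizes, and the groups they produce:
#
#   x: int = ""  # E: incompatible assignment   -> description="incompatible assignment"
#   y = f()      # E?                           -> optional="?"
#   z = g()      # E[bad_call]                  -> tag="bad_call"
#   w = h()      # E[bad_call+]: repeated       -> tag="bad_call", multi="+"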

CONFORMANCE_SUITE_COMMIT = os.environ.get("CONFORMANCE_SUITE_COMMIT", "main")
CONFORMANCE_DIR_WITH_README = (
    f"https://github.com/python/typing/blob/{CONFORMANCE_SUITE_COMMIT}/conformance/"
)
CONFORMANCE_URL = CONFORMANCE_DIR_WITH_README + "tests/{filename}#L{line}"
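
# With the default commit, links take the form (filename is hypothetical):
#   https://github.com/python/typing/blob/main/conformance/tests/example_test.py#L42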


class Source(Flag):
    OLD = auto()
    NEW = auto()
    EXPECTED = auto()
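
# Source is a Flag, so values combine with `|` and support membership tests.
# For example, a location reported by both runs and expected by the suite
# carries sources == Source.OLD | Source.NEW | Source.EXPECTED, and
# `Source.OLD in sources` is True.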


class Classification(StrEnum):
    TRUE_POSITIVE = auto()
    FALSE_POSITIVE = auto()
    TRUE_NEGATIVE = auto()
    FALSE_NEGATIVE = auto()

    def into_title(self) -> str:
        match self:
            case Classification.TRUE_POSITIVE:
                return "True positives added"
            case Classification.FALSE_POSITIVE:
                return "False positives added"
            case Classification.TRUE_NEGATIVE:
                return "False positives removed"
            case Classification.FALSE_NEGATIVE:
                return "True positives removed"


@dataclass(kw_only=True, slots=True)
class Position:
    line: int
    column: int


@dataclass(kw_only=True, slots=True)
class Positions:
    begin: Position
    end: Position


@dataclass(kw_only=True, slots=True)
class Location:
    path: Path
    positions: Positions

    def as_link(self) -> str:
        file = self.path.name
        link = CONFORMANCE_URL.format(
            filename=file,
            line=self.positions.begin.line,
        )
        return f"[{file}:{self.positions.begin.line}:{self.positions.begin.column}]({link})"


@dataclass(kw_only=True, slots=True)
class Diagnostic:
    check_name: str
    description: str
    severity: str
    fingerprint: str | None
    location: Location
    source: Source

    def __post_init__(self) -> None:
        # Escape pipe characters for GitHub markdown tables
        self.description = self.description.replace("|", "\\|")

    def __str__(self) -> str:
        return (
            f"{self.location.path}:{self.location.positions.begin.line}:"
            f"{self.location.positions.begin.column}: "
            f"{self.severity_for_display}[{self.check_name}] {self.description}"
        )

    @classmethod
    def from_gitlab_output(
        cls,
        dct: dict[str, Any],
        source: Source,
    ) -> Self:
        return cls(
            check_name=dct["check_name"],
            description=dct["description"],
            severity=dct["severity"],
            fingerprint=dct["fingerprint"],
            location=Location(
                path=Path(dct["location"]["path"]).resolve(),
                positions=Positions(
                    begin=Position(
                        line=dct["location"]["positions"]["begin"]["line"],
                        column=dct["location"]["positions"]["begin"]["column"],
                    ),
                    end=Position(
                        line=dct["location"]["positions"]["end"]["line"],
                        column=dct["location"]["positions"]["end"]["column"],
                    ),
                ),
            ),
            source=source,
        )

    @property
    def key(self) -> str:
        """Key to group diagnostics by path and beginning line."""
        return f"{self.location.path.as_posix()}:{self.location.positions.begin.line}"

    @property
    def severity_for_display(self) -> str:
        return {
            "major": "error",
            "minor": "warning",
        }.get(self.severity, "unknown")


@dataclass(kw_only=True, slots=True)
class GroupedDiagnostics:
    key: str
    sources: Source
    old: Diagnostic | None
    new: Diagnostic | None
    expected: Diagnostic | None

    @property
    def changed(self) -> bool:
        return (Source.OLD in self.sources or Source.NEW in self.sources) and not (
            Source.OLD in self.sources and Source.NEW in self.sources
        )
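
    # `changed` is an XOR over the two runs: exactly one of OLD and NEW
    # reported this location, i.e. the diagnostic appeared in or disappeared
    # from the new version.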

    @property
    def classification(self) -> Classification:
        if Source.NEW in self.sources and Source.EXPECTED in self.sources:
            return Classification.TRUE_POSITIVE
        elif Source.NEW in self.sources and Source.EXPECTED not in self.sources:
            return Classification.FALSE_POSITIVE
        elif Source.EXPECTED in self.sources:
            return Classification.FALSE_NEGATIVE
        else:
            return Classification.TRUE_NEGATIVE
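
    # Truth table for the classification above (did NEW report it / does the
    # suite EXPECT it):
    #   NEW and EXPECTED         -> TRUE_POSITIVE
    #   NEW and not EXPECTED     -> FALSE_POSITIVE
    #   not NEW and EXPECTED     -> FALSE_NEGATIVE
    #   not NEW and not EXPECTED -> TRUE_NEGATIVE (reachable only via OLD)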

    def _render_row(self, diagnostic: Diagnostic) -> str:
        return f"| {diagnostic.location.as_link()} | {diagnostic.check_name} | {diagnostic.description} |"

    def _render_diff(self, diagnostic: Diagnostic, *, removed: bool = False) -> str:
        sign = "-" if removed else "+"
        return f"{sign} {diagnostic}"

    def display(self, format: Literal["diff", "github"]) -> str:
        match self.classification:
            case Classification.TRUE_POSITIVE | Classification.FALSE_POSITIVE:
                assert self.new is not None
                return (
                    self._render_diff(self.new)
                    if format == "diff"
                    else self._render_row(self.new)
                )

            case Classification.FALSE_NEGATIVE | Classification.TRUE_NEGATIVE:
                diagnostic = self.old or self.expected
                assert diagnostic is not None
                return (
                    self._render_diff(diagnostic, removed=True)
                    if format == "diff"
                    else self._render_row(diagnostic)
                )

            case _:
                raise ValueError(f"Unexpected classification: {self.classification}")


@dataclass(kw_only=True, slots=True)
class Statistics:
    true_positives: int = 0
    false_positives: int = 0
    false_negatives: int = 0

    @property
    def precision(self) -> float:
        if self.true_positives + self.false_positives > 0:
            return self.true_positives / (self.true_positives + self.false_positives)
        return 0.0

    @property
    def recall(self) -> float:
        if self.true_positives + self.false_negatives > 0:
            return self.true_positives / (self.true_positives + self.false_negatives)
        return 0.0

    @property
    def total(self) -> int:
        return self.true_positives + self.false_positives
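
# As implemented above: precision = TP / (TP + FP), the share of emitted
# diagnostics that were expected errors; recall = TP / (TP + FN), the share of
# expected errors that received a diagnostic. `total` counts only emitted
# diagnostics (TP + FP), not false negatives.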


def collect_expected_diagnostics(path: Path) -> list[Diagnostic]:
    diagnostics: list[Diagnostic] = []
    for file in path.resolve().rglob("*.py"):
        for idx, line in enumerate(file.read_text().splitlines(), 1):
            if error := re.search(CONFORMANCE_ERROR_PATTERN, line):
                diagnostics.append(
                    Diagnostic(
                        check_name="conformance",
                        description=(
                            error.group("description")
                            or error.group("tag")
                            or "Missing"
                        ),
                        severity="major",
                        fingerprint=None,
                        location=Location(
                            path=file,
                            positions=Positions(
                                begin=Position(
                                    line=idx,
                                    column=error.start(),
                                ),
                                end=Position(
                                    line=idx,
                                    column=error.end(),
                                ),
                            ),
                        ),
                        source=Source.EXPECTED,
                    )
                )

    assert diagnostics, "Failed to discover any expected diagnostics!"
    return diagnostics

def collect_ty_diagnostics(
    ty_path: list[str],
    source: Source,
    tests_path: str = ".",
    python_version: str = "3.12",
) -> list[Diagnostic]:
    process = subprocess.run(
        [
            *ty_path,
            "check",
            f"--python-version={python_version}",
            "--output-format=gitlab",
            "--exit-zero",
            tests_path,
        ],
        capture_output=True,
        text=True,
        # `--exit-zero` makes ty exit with code 0 even when diagnostics are
        # found, so a nonzero exit here indicates a crash or usage error.
        # Use check=False so we can surface stderr ourselves below.
        check=False,
        timeout=15,
    )

    if process.returncode != 0:
        print(process.stderr)
        raise RuntimeError(f"ty check failed with exit code {process.returncode}")

    return [
        Diagnostic.from_gitlab_output(dct, source=source)
        for dct in json.loads(process.stdout)
    ]


def group_diagnostics_by_key(
    old: list[Diagnostic], new: list[Diagnostic], expected: list[Diagnostic]
) -> list[GroupedDiagnostics]:
    diagnostics = [
        *old,
        *new,
        *expected,
    ]
    sorted_diagnostics = sorted(diagnostics, key=attrgetter("key"))

    grouped = []
    for key, group in groupby(sorted_diagnostics, key=attrgetter("key")):
        group = list(group)
        sources: Source = reduce(or_, (diag.source for diag in group))
        grouped.append(
            GroupedDiagnostics(
                key=key,
                sources=sources,
                old=next(filter(lambda diag: diag.source == Source.OLD, group), None),
                new=next(filter(lambda diag: diag.source == Source.NEW, group), None),
                expected=next(
                    filter(lambda diag: diag.source == Source.EXPECTED, group), None
                ),
            )
        )

    return grouped
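
# Grouping is keyed on "path:line" (see Diagnostic.key), so differing column
# numbers between ty's output and the expected-error comment don't prevent a
# match. If one source reports several diagnostics on the same line, only the
# first fills that slot, though every source is still recorded in `sources`.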


def compute_stats(
    grouped_diagnostics: list[GroupedDiagnostics], source: Source
) -> Statistics:
    if source is Source.EXPECTED:
        # ty currently raises a false positive here due to incomplete enum.Flag support
        # see https://github.com/astral-sh/ty/issues/876
        num_errors = sum(
            1
            for g in grouped_diagnostics
            if source.EXPECTED in g.sources  # ty:ignore[unsupported-operator]
        )
        return Statistics(
            true_positives=num_errors, false_positives=0, false_negatives=0
        )

    def increment(statistics: Statistics, grouped: GroupedDiagnostics) -> Statistics:
        if (source in grouped.sources) and (Source.EXPECTED in grouped.sources):
            statistics.true_positives += 1
        elif source in grouped.sources:
            statistics.false_positives += 1
        elif Source.EXPECTED in grouped.sources:
            statistics.false_negatives += 1
        return statistics

    return reduce(increment, grouped_diagnostics, Statistics())


def render_grouped_diagnostics(
    grouped: list[GroupedDiagnostics],
    *,
    changed_only: bool = True,
    format: Literal["diff", "github"] = "diff",
) -> str:
    if changed_only:
        grouped = [diag for diag in grouped if diag.changed]

    sorted_by_class = sorted(
        grouped,
        key=attrgetter("classification"),
        reverse=True,
    )

    match format:
        case "diff":
            header = ["```diff"]
            footer = "```"
        case "github":
            header = [
                "| Location | Name | Message |",
                "|----------|------|---------|",
            ]
            footer = ""
        case _:
            raise ValueError("format must be one of 'diff' or 'github'")

    lines = []
    for classification, group in groupby(
        sorted_by_class, key=attrgetter("classification")
    ):
        group = list(group)

        lines.append(f"### {classification.into_title()}")
        lines.extend(["", "<details>", ""])

        lines.extend(header)

        for diag in group:
            lines.append(diag.display(format=format))

        lines.append(footer)
        lines.extend(["", "</details>", ""])

    return "\n".join(lines)


def diff_format(
    diff: float,
    *,
    greater_is_better: bool = True,
    neutral: bool = False,
) -> str:
    if diff == 0:
        return ""

    increased = diff > 0
    good = " (✅)" if not neutral else ""
    bad = " (❌)" if not neutral else ""
    up = "⏫"
    down = "⏬"

    match (greater_is_better, increased):
        case (True, True):
            return f"{up}{good}"
        case (False, True):
            return f"{up}{bad}"
        case (True, False):
            return f"{down}{bad}"
        case (False, False):
            return f"{down}{good}"
        case _:
            # The ty false positive seems to be due to insufficient type narrowing for tuples;
            # possibly related to https://github.com/astral-sh/ty/issues/493 and/or
            # https://github.com/astral-sh/ty/issues/887
            assert_never((greater_is_better, increased))  # ty: ignore[type-assertion-failure]
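
# Illustrative outputs: diff_format(3) -> "⏫ (✅)"; diff_format(-0.02,
# greater_is_better=False) -> "⏬ (✅)"; diff_format(5, neutral=True) -> "⏫".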


def render_summary(grouped_diagnostics: list[GroupedDiagnostics]) -> str:
    def format_metric(diff: float, old: float, new: float) -> str:
        if diff > 0:
            return f"increased from {old:.2%} to {new:.2%}"
        if diff < 0:
            return f"decreased from {old:.2%} to {new:.2%}"
        return f"held steady at {old:.2%}"

    old = compute_stats(grouped_diagnostics, source=Source.OLD)
    new = compute_stats(grouped_diagnostics, source=Source.NEW)

    assert new.true_positives > 0, (
        "Expected ty to have at least one true positive. "
        f"Sample of grouped diagnostics: {grouped_diagnostics[:5]}"
    )

    precision_change = new.precision - old.precision
    recall_change = new.recall - old.recall
    true_pos_change = new.true_positives - old.true_positives
    false_pos_change = new.false_positives - old.false_positives
    false_neg_change = new.false_negatives - old.false_negatives
    total_change = new.total - old.total

    base_header = f"[Typing conformance results]({CONFORMANCE_DIR_WITH_README})"

    if (
        precision_change == 0
        and recall_change == 0
        and true_pos_change == 0
        and false_pos_change == 0
        and false_neg_change == 0
        and total_change == 0
    ):
        return dedent(
            f"""
            ## {base_header}

            No changes detected ✅
            """
        )

    true_pos_diff = diff_format(true_pos_change, greater_is_better=True)
    false_pos_diff = diff_format(false_pos_change, greater_is_better=False)
    false_neg_diff = diff_format(false_neg_change, greater_is_better=False)
    precision_diff = diff_format(precision_change, greater_is_better=True)
    recall_diff = diff_format(recall_change, greater_is_better=True)
    total_diff = diff_format(total_change, neutral=True)

    if (precision_change > 0 and recall_change >= 0) or (
        recall_change > 0 and precision_change >= 0
    ):
        header = f"{base_header} improved 🎉"
    elif (precision_change < 0 and recall_change <= 0) or (
        recall_change < 0 and precision_change <= 0
    ):
        header = f"{base_header} regressed ❌"
    else:
        header = base_header

    summary_paragraph = (
        f"The percentage of diagnostics emitted that were expected errors "
        f"{format_metric(precision_change, old.precision, new.precision)}. "
        f"The percentage of expected errors that received a diagnostic "
        f"{format_metric(recall_change, old.recall, new.recall)}."
    )

    return dedent(
        f"""
        ## {header}

        {summary_paragraph}

        ### Summary

        | Metric | Old | New | Diff | Outcome |
        |--------|-----|-----|------|---------|
        | True Positives | {old.true_positives} | {new.true_positives} | {true_pos_change:+} | {true_pos_diff} |
        | False Positives | {old.false_positives} | {new.false_positives} | {false_pos_change:+} | {false_pos_diff} |
        | False Negatives | {old.false_negatives} | {new.false_negatives} | {false_neg_change:+} | {false_neg_diff} |
        | Total Diagnostics | {old.total} | {new.total} | {total_change:+} | {total_diff} |
        | Precision | {old.precision:.2%} | {new.precision:.2%} | {precision_change:+.2%} | {precision_diff} |
        | Recall | {old.recall:.2%} | {new.recall:.2%} | {recall_change:+.2%} | {recall_diff} |
        """
    )


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--old-ty",
        nargs="+",
        required=True,
        help="Command to run old version of ty",
    )

    parser.add_argument(
        "--new-ty",
        nargs="+",
        default=["uvx", "ty@latest"],
        help="Command to run new version of ty (default: uvx ty@latest)",
    )

    parser.add_argument(
        "--tests-path",
        type=Path,
        default=Path("typing/conformance/tests"),
        help="Path to conformance tests directory (default: typing/conformance/tests)",
    )

    parser.add_argument(
        "--python-version",
        type=str,
        default="3.12",
        help="Python version to assume when running ty (default: 3.12)",
    )

    parser.add_argument(
        "--all",
        action="store_true",
        help="Show all diagnostics, not just changed ones",
    )

    parser.add_argument(
        "--format",
        type=str,
        choices=["diff", "github"],
        default="github",
        help="Output format (default: github)",
    )

    parser.add_argument(
        "--output",
        type=Path,
        help="Write output to file instead of stdout",
    )

    return parser.parse_args()


def main() -> None:
    args = parse_args()

    tests_path = args.tests_path.resolve()

    expected = collect_expected_diagnostics(tests_path)

    old = collect_ty_diagnostics(
        ty_path=args.old_ty,
        tests_path=str(tests_path),
        source=Source.OLD,
        python_version=args.python_version,
    )

    new = collect_ty_diagnostics(
        ty_path=args.new_ty,
        tests_path=str(tests_path),
        source=Source.NEW,
        python_version=args.python_version,
    )

    grouped = group_diagnostics_by_key(
        old=old,
        new=new,
        expected=expected,
    )

    rendered = "\n\n".join(
        [
            render_summary(grouped),
            render_grouped_diagnostics(
                grouped, changed_only=not args.all, format=args.format
            ),
        ]
    )

    if args.output:
        args.output.write_text(rendered, encoding="utf-8")
        print(f"Output written to {args.output}", file=sys.stderr)
        print(rendered, file=sys.stderr)
    else:
        print(rendered)


if __name__ == "__main__":
    main()
|