mirror of https://github.com/astral-sh/ruff
169 lines
5.7 KiB
Python
169 lines
5.7 KiB
Python
"""Simple LSP client for benchmarking diagnostic response times."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from asyncio import Future
|
|
from pathlib import Path
|
|
from typing import Any, NamedTuple, override
|
|
|
|
from lsprotocol import types as lsp
|
|
from pygls.lsp.client import LanguageClient
|
|
|
|
|
|
def _register_notebook_structure_hooks(converter):
    """Install a structure hook for the notebook-filter union on `converter`.

    cattrs has trouble deserializing the
    `str | NotebookDocumentFilter* | None` union used by lsprotocol's
    notebook document types, so this registers a hand-written hook that
    picks the concrete filter type from the keys present in the payload.
    """

    # Discriminating key -> concrete lsprotocol type, checked in this order.
    filter_type_by_key = (
        ("notebookType", lsp.NotebookDocumentFilterNotebookType),
        ("scheme", lsp.NotebookDocumentFilterScheme),
        ("pattern", lsp.NotebookDocumentFilterPattern),
    )

    def structure_notebook_filter(obj: Any, _type):
        """Structure a notebook filter value; non-dict values pass through."""
        if isinstance(obj, dict):
            for key, filter_type in filter_type_by_key:
                if key in obj:
                    return converter.structure(obj, filter_type)

        # `None`, plain strings, dicts without a known key, and anything
        # else are returned unchanged.
        return obj

    # The union type that cattrs struggles with.
    problematic_union = (
        str
        | lsp.NotebookDocumentFilterNotebookType
        | lsp.NotebookDocumentFilterScheme
        | lsp.NotebookDocumentFilterPattern
        | None
    )
    converter.register_structure_hook(problematic_union, structure_notebook_filter)
|
|
|
|
|
|
class LSPClient(LanguageClient):
    """A minimal LSP client for benchmarking purposes.

    The client records `lsp.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS` notifications
    per document URI in `diagnostics` and forwards
    `lsp.WINDOW_LOG_MESSAGE` notifications to the `logging` module.
    """

    # Assigned in `initialize_async`; reading it before initialization raises
    # `AttributeError` because no class-level default exists.
    server_capabilities: lsp.ServerCapabilities
    # Maps a document URI to the future that carries (or will carry) the
    # publish-diagnostics notification for that document.
    diagnostics: dict[str, Future[lsp.PublishDiagnosticsParams]]

    def __init__(
        self,
    ):
        """Create the client and register its notification handlers."""
        super().__init__(
            "ty_benchmark",
            "v1",
        )

        # Register custom structure hooks to work around lsprotocol/cattrs issues.
        _register_notebook_structure_hooks(self.protocol._converter)

        self.diagnostics = {}

        @self.feature(lsp.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS)
        def publish_diagnostics(
            client: LSPClient, params: lsp.PublishDiagnosticsParams
        ):
            logging.info(
                f"Received publish_diagnostics for {params.uri} with version={params.version}, diagnostics count={len(params.diagnostics)}"
            )
            future = self.diagnostics.get(params.uri)

            # Start a fresh future when nobody is waiting yet, or when the
            # stored one was already resolved/cancelled (a done future cannot
            # take another result).
            if future is None or future.done():
                future = asyncio.Future()
                self.diagnostics[params.uri] = future

            future.set_result(params)

        @self.feature(lsp.WINDOW_LOG_MESSAGE)
        def log_message(client: LSPClient, params: lsp.LogMessageParams):
            # Map the server's message type onto the matching logging level;
            # everything that is not an error or warning is logged as info.
            if params.type == lsp.MessageType.Error:
                logging.error(f"server error: {params.message}")
            elif params.type == lsp.MessageType.Warning:
                logging.warning(f"server warning: {params.message}")
            else:
                logging.info(f"server info: {params.message}")

    @override
    async def initialize_async(
        self, params: lsp.InitializeParams
    ) -> lsp.InitializeResult:
        """Initialize the server and remember the capabilities it reports."""
        result = await super().initialize_async(params)

        self.server_capabilities = result.capabilities

        logging.info(
            f"Pull diagnostic support: {self.server_supports_pull_diagnostics}"
        )

        return result

    def clear_pending_publish_diagnostics(self):
        """Forget every stored publish-diagnostics future."""
        self.diagnostics.clear()

    @property
    def server_supports_pull_diagnostics(self) -> bool:
        """Whether the server advertised a `diagnostic_provider` capability."""
        diagnostic_provider = self.server_capabilities.diagnostic_provider

        return diagnostic_provider is not None

    async def text_document_diagnostics_async(self, path: Path) -> list[lsp.Diagnostic]:
        """
        Returns the diagnostics for `path`

        Uses pull diagnostics if the server supports it or waits for a publish diagnostics
        notification if not.
        """
        if self.server_supports_pull_diagnostics:
            pull_diagnostics = await self.text_document_diagnostic_async(
                lsp.DocumentDiagnosticParams(
                    text_document=lsp.TextDocumentIdentifier(uri=path.as_uri())
                )
            )

            assert isinstance(
                pull_diagnostics, lsp.RelatedFullDocumentDiagnosticReport
            ), "Expected a full diagnostic report"

            return list(pull_diagnostics.items)

        # Use publish diagnostics otherwise.
        # Pyrefly doesn't support pull diagnostics as of today (27th of November 2025)
        publish_diagnostics = await self.wait_for_push_diagnostics_async(path)
        return list(publish_diagnostics.diagnostics)

    # The return annotation is quoted because `FileDiagnostics` is defined
    # further down the module; an unquoted forward reference fails with
    # `NameError` on interpreters that evaluate annotations eagerly.
    async def text_documents_diagnostics_async(
        self, files: list[Path]
    ) -> "list[FileDiagnostics]":
        """Request diagnostics for all `files` concurrently.

        Returns one `FileDiagnostics` per input path, in the same order as
        `files`.
        """
        responses = await asyncio.gather(
            *(self.text_document_diagnostics_async(f) for f in files)
        )

        return [
            FileDiagnostics(file, diagnostics=list(response))
            for file, response in zip(files, responses)
        ]

    async def wait_for_push_diagnostics_async(
        self, path: Path, timeout: float = 60
    ) -> lsp.PublishDiagnosticsParams:
        """Wait for a publish-diagnostics notification for `path`.

        Returns the notification parameters. If a notification for `path`
        was already recorded, it is returned without waiting. The entry for
        `path` is removed from `diagnostics` afterwards, whether the wait
        succeeded or not.

        Raises `TimeoutError` if nothing arrives within `timeout` seconds.
        """
        uri = path.as_uri()
        future = self.diagnostics.get(uri)

        if future is None:
            future = asyncio.Future()
            self.diagnostics[uri] = future

        try:
            logging.info(f"Waiting for push diagnostics for {path}")
            result = await asyncio.wait_for(future, timeout)
            logging.info(f"Awaited push diagnostics for {path}")
        finally:
            # The entry may already be gone (for example after
            # `clear_pending_publish_diagnostics()` or a concurrent waiter
            # for the same path). Pop with a default so cleanup never raises
            # `KeyError` and hides the real result or timeout.
            self.diagnostics.pop(uri, None)

        return result
|
|
|
|
|
|
class FileDiagnostics(NamedTuple):
    """The diagnostics collected for a single file."""

    # Path of the file the diagnostics belong to.
    file: Path
    # Diagnostics reported for `file`.
    diagnostics: list[lsp.Diagnostic]
|