ruff/scripts/ty_benchmark/src/benchmark/lsp_client.py

169 lines
5.7 KiB
Python

"""Simple LSP client for benchmarking diagnostic response times."""
import asyncio
import logging
from asyncio import Future
from pathlib import Path
from typing import Any, NamedTuple, override
from lsprotocol import types as lsp
from pygls.lsp.client import LanguageClient
def _register_notebook_structure_hooks(converter):
"""Register structure hooks for notebook document types to work around cattrs deserialization issues."""
# Define a union type that cattrs struggles with.
notebook_filter_union = (
str
| lsp.NotebookDocumentFilterNotebookType
| lsp.NotebookDocumentFilterScheme
| lsp.NotebookDocumentFilterPattern
| None
)
def structure_notebook_filter(obj: Any, _type):
"""Structure a notebook filter field from various possible types."""
if obj is None:
return None
if isinstance(obj, str):
return obj
if isinstance(obj, dict):
# Try to structure it as one of the known types.
if "notebookType" in obj:
return converter.structure(obj, lsp.NotebookDocumentFilterNotebookType)
elif "scheme" in obj:
return converter.structure(obj, lsp.NotebookDocumentFilterScheme)
elif "pattern" in obj:
return converter.structure(obj, lsp.NotebookDocumentFilterPattern)
return obj
converter.register_structure_hook(notebook_filter_union, structure_notebook_filter)
class LSPClient(LanguageClient):
"""A minimal LSP client for benchmarking purposes."""
server_capabilities: lsp.ServerCapabilities
diagnostics: dict[str, Future[lsp.PublishDiagnosticsParams]]
def __init__(
self,
):
super().__init__(
"ty_benchmark",
"v1",
)
# Register custom structure hooks to work around lsprotocol/cattrs issues.
_register_notebook_structure_hooks(self.protocol._converter)
self.diagnostics = {}
@self.feature(lsp.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS)
def publish_diagnostics(
client: LSPClient, params: lsp.PublishDiagnosticsParams
):
logging.info(
f"Received publish_diagnostics for {params.uri} with version={params.version}, diagnostics count={len(params.diagnostics)}"
)
future = self.diagnostics.get(params.uri, None)
if future is None or future.done():
future = asyncio.Future()
self.diagnostics[params.uri] = future
future.set_result(params)
@self.feature(lsp.WINDOW_LOG_MESSAGE)
def log_message(client: LSPClient, params: lsp.LogMessageParams):
if params.type == lsp.MessageType.Error:
logging.error(f"server error: {params.message}")
elif params.type == lsp.MessageType.Warning:
logging.warning(f"server warning: {params.message}")
else:
logging.info(f"server info: {params.message}")
@override
async def initialize_async(
self, params: lsp.InitializeParams
) -> lsp.InitializeResult:
result = await super().initialize_async(params)
self.server_capabilities = result.capabilities
logging.info(
f"Pull diagnostic support: {self.server_supports_pull_diagnostics}"
)
return result
def clear_pending_publish_diagnostics(self):
self.diagnostics.clear()
@property
def server_supports_pull_diagnostics(self) -> bool:
diagnostic_provider = self.server_capabilities.diagnostic_provider
return diagnostic_provider is not None
async def text_document_diagnostics_async(self, path: Path) -> list[lsp.Diagnostic]:
"""
Returns the diagnostics for `path`
Uses pull diagnostics if the server supports it or waits for a publish diagnostics
notification if not.
"""
if self.server_supports_pull_diagnostics:
pull_diagnostics = await self.text_document_diagnostic_async(
lsp.DocumentDiagnosticParams(
text_document=lsp.TextDocumentIdentifier(uri=path.as_uri())
)
)
assert isinstance(
pull_diagnostics, lsp.RelatedFullDocumentDiagnosticReport
), "Expected a full diagnostic report"
return list(pull_diagnostics.items)
# Use publish diagnostics otherwise.
# Pyrefly doesn't support pull diagnostics as of today (27th of November 2025)
publish_diagnostics = await self.wait_for_push_diagnostics_async(path)
return list(publish_diagnostics.diagnostics)
async def text_documents_diagnostics_async(
self, files: list[Path]
) -> list[FileDiagnostics]:
responses = await asyncio.gather(
*(self.text_document_diagnostics_async(f) for f in files)
)
return [
FileDiagnostics(file, diagnostics=list(response))
for file, response in zip(files, responses)
]
async def wait_for_push_diagnostics_async(
self, path: Path, timeout: float = 60
) -> lsp.PublishDiagnosticsParams:
future = self.diagnostics.get(path.as_uri(), None)
if future is None:
future = asyncio.Future()
self.diagnostics[path.as_uri()] = future
try:
logging.info(f"Waiting for push diagnostics for {path}")
result = await asyncio.wait_for(future, timeout)
logging.info(f"Awaited push diagnostics for {path}")
finally:
self.diagnostics.pop(path.as_uri())
return result
class FileDiagnostics(NamedTuple):
file: Path
diagnostics: list[lsp.Diagnostic]