From c12b4592baf78a10bbe3efb9358ecc48c1a23640 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Thu, 5 Oct 2023 18:55:57 +0530 Subject: [PATCH] Add support for Jupyter Notebook --- ruff_lsp/server.py | 565 ++++++++++++++++++++++++++++++++++--------- ruff_lsp/settings.py | 19 +- tests/test_format.py | 15 +- tests/test_server.py | 6 +- 4 files changed, 484 insertions(+), 121 deletions(-) diff --git a/ruff_lsp/server.py b/ruff_lsp/server.py index 827c4de..c58a25c 100755 --- a/ruff_lsp/server.py +++ b/ruff_lsp/server.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import enum import json import logging import os @@ -10,13 +11,19 @@ import shutil import sys import sysconfig +from collections.abc import Iterable, Mapping +from dataclasses import dataclass from pathlib import Path -from typing import NamedTuple, Sequence, cast +from typing import Any, NamedTuple, Sequence, cast from lsprotocol import validators from lsprotocol.types import ( CODE_ACTION_RESOLVE, INITIALIZE, + NOTEBOOK_DOCUMENT_DID_CHANGE, + NOTEBOOK_DOCUMENT_DID_CLOSE, + NOTEBOOK_DOCUMENT_DID_OPEN, + NOTEBOOK_DOCUMENT_DID_SAVE, TEXT_DOCUMENT_CODE_ACTION, TEXT_DOCUMENT_DID_CHANGE, TEXT_DOCUMENT_DID_CLOSE, @@ -34,9 +41,13 @@ Diagnostic, DiagnosticSeverity, DiagnosticTag, + DidChangeNotebookDocumentParams, DidChangeTextDocumentParams, + DidCloseNotebookDocumentParams, DidCloseTextDocumentParams, + DidOpenNotebookDocumentParams, DidOpenTextDocumentParams, + DidSaveNotebookDocumentParams, DidSaveTextDocumentParams, DocumentFormattingParams, Hover, @@ -45,6 +56,11 @@ MarkupContent, MarkupKind, MessageType, + NotebookCellKind, + NotebookDocument, + NotebookDocumentSyncOptions, + NotebookDocumentSyncOptionsNotebookSelectorType2, + NotebookDocumentSyncOptionsNotebookSelectorType2CellsType, OptionalVersionedTextDocumentIdentifier, Position, Range, @@ -54,10 +70,11 @@ ) from packaging.specifiers import SpecifierSet, Version from pygls import server, uris, workspace -from typing_extensions import TypedDict +from typing_extensions import Literal, Self, TypedDict, assert_never from ruff_lsp import __version__, utils from ruff_lsp.settings import ( + Run, UserSettings, WorkspaceSettings, lint_args, @@ -105,6 +122,18 @@ class VersionModified(NamedTuple): name="Ruff", version=__version__, max_workers=MAX_WORKERS, + notebook_document_sync=NotebookDocumentSyncOptions( + notebook_selector=[ + NotebookDocumentSyncOptionsNotebookSelectorType2( + cells=[ + NotebookDocumentSyncOptionsNotebookSelectorType2CellsType( + language="python" + ) + ] + ) + ], + save=True, + ), ) TOOL_MODULE = "ruff.exe" if sys.platform == "win32" else "ruff" @@ -178,6 +207,149 @@ class VersionModified(NamedTuple): ] +### +# Document +### + + +@enum.unique +class DocumentKind(enum.Enum): + """The kind of document.""" + + Text = enum.auto() + """A Python file or a cell in a Notebook Document.""" + + Notebook = enum.auto() + """A Notebook Document.""" + + +@dataclass(frozen=True) +class Document: + """A document representing either a Python file, a Notebook cell, or a Notebook.""" + + uri: str + path: str + source: str + kind: DocumentKind + version: int | None + + @classmethod + def from_text_document(cls, text_document: workspace.TextDocument) -> Self: + """Create a `Document` from the given Text Document.""" + return cls( + uri=text_document.uri, + path=text_document.path, + kind=DocumentKind.Text, + source=text_document.source, + version=text_document.version, + ) + + @classmethod + def from_notebook_document(cls, notebook_document: NotebookDocument) -> Self: + """Create a `Document` from the given Notebook Document.""" + path = uris.to_fs_path(notebook_document.uri) + if path is None: + # `pygls` raises a `Exception` as well in `workspace.TextDocument`. + raise ValueError( + f"Unable to convert URI to file path: {notebook_document.uri}" + ) + return cls( + uri=notebook_document.uri, + path=path, + kind=DocumentKind.Notebook, + source=_create_notebook_json(notebook_document), + version=notebook_document.version, + ) + + @classmethod + def from_uri(cls, uri: str) -> Self: + """Create a `Document` from the given URI. + + The URI can be a file URI, a notebook URI, or a cell URI. The function will + try to get the notebook document first, and if that fails, it will fallback + to the text document. + """ + notebook_document = LSP_SERVER.workspace.get_notebook_document(cell_uri=uri) + if notebook_document is None: + notebook_document = LSP_SERVER.workspace.get_notebook_document( + notebook_uri=uri + ) + if notebook_document is None: + text_document = LSP_SERVER.workspace.get_text_document(uri) + return cls.from_text_document(text_document) + else: + return cls.from_notebook_document(notebook_document) + + def is_stdlib_file(self) -> bool: + """Return True if the document belongs to standard library.""" + return utils.is_stdlib_file(self.path) + + def is_notebook_file(self) -> bool: + """Return True if the document belongs to a Notebook or a cell in a Notebook.""" + return self.kind is DocumentKind.Notebook or self.path.endswith(".ipynb") + + +SourceValue = str | list[str] + + +class CodeCell(TypedDict): + """A code cell in a Jupyter notebook.""" + + cell_type: Literal["code"] + metadata: Any + outputs: list[Any] + source: SourceValue + + +class MarkdownCell(TypedDict): + """A markdown cell in a Jupyter notebook.""" + + cell_type: Literal["markdown"] + metadata: Any + source: SourceValue + + +class Notebook(TypedDict): + """The JSON representation of a Notebook Document.""" + + metadata: Any + nbformat: int + nbformat_minor: int + cells: list[CodeCell | MarkdownCell] + + +def _create_notebook_json(notebook_document: NotebookDocument) -> str: + """Create a JSON representation of the given Notebook Document.""" + cells: list[CodeCell | MarkdownCell] = [] + for notebook_cell in notebook_document.cells: + cell_document = LSP_SERVER.workspace.get_text_document(notebook_cell.document) + if notebook_cell.kind is NotebookCellKind.Code: + cells.append( + { + "cell_type": "code", + "metadata": None, + "outputs": [], + "source": cell_document.source, + } + ) + else: + cells.append( + { + "cell_type": "markdown", + "metadata": None, + "source": cell_document.source, + } + ) + return json.dumps( + { + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5, + "cells": cells, + } + ) + + ### # Linting. ### @@ -186,38 +358,137 @@ class VersionModified(NamedTuple): @LSP_SERVER.feature(TEXT_DOCUMENT_DID_OPEN) async def did_open(params: DidOpenTextDocumentParams) -> None: """LSP handler for textDocument/didOpen request.""" - document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) - diagnostics: list[Diagnostic] = await _lint_document_impl(document) + document = Document.from_text_document( + LSP_SERVER.workspace.get_text_document(params.text_document.uri) + ) + diagnostics = await _lint_document_impl(document) LSP_SERVER.publish_diagnostics(document.uri, diagnostics) @LSP_SERVER.feature(TEXT_DOCUMENT_DID_CLOSE) def did_close(params: DidCloseTextDocumentParams) -> None: """LSP handler for textDocument/didClose request.""" - document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) + text_document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) # Publishing empty diagnostics to clear the entries for this file. - LSP_SERVER.publish_diagnostics(document.uri, []) + LSP_SERVER.publish_diagnostics(text_document.uri, []) @LSP_SERVER.feature(TEXT_DOCUMENT_DID_SAVE) async def did_save(params: DidSaveTextDocumentParams) -> None: """LSP handler for textDocument/didSave request.""" - document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) - if lint_run(_get_settings_by_document(document.path)) in ("onType", "onSave"): - diagnostics: list[Diagnostic] = await _lint_document_impl(document) + text_document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) + if lint_run(_get_settings_by_document(text_document.path)) in ( + Run.OnType, + Run.OnSave, + ): + document = Document.from_text_document(text_document) + diagnostics = await _lint_document_impl(document) LSP_SERVER.publish_diagnostics(document.uri, diagnostics) @LSP_SERVER.feature(TEXT_DOCUMENT_DID_CHANGE) async def did_change(params: DidChangeTextDocumentParams) -> None: """LSP handler for textDocument/didChange request.""" - document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) - if lint_run(_get_settings_by_document(document.path)) == "onType": - diagnostics: list[Diagnostic] = await _lint_document_impl(document) + text_document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) + if lint_run(_get_settings_by_document(text_document.path)) == Run.OnType: + document = Document.from_text_document(text_document) + diagnostics = await _lint_document_impl(document) LSP_SERVER.publish_diagnostics(document.uri, diagnostics) -async def _lint_document_impl(document: workspace.TextDocument) -> list[Diagnostic]: +@LSP_SERVER.feature(NOTEBOOK_DOCUMENT_DID_OPEN) +async def did_open_notebook(params: DidOpenNotebookDocumentParams) -> None: + """LSP handler for notebookDocument/didOpen request.""" + notebook_document = LSP_SERVER.workspace.get_notebook_document( + notebook_uri=params.notebook_document.uri + ) + if notebook_document is None: + log_warning(f"No notebook document found for {params.notebook_document.uri!r}") + return + + document = Document.from_notebook_document(notebook_document) + diagnostics = await _lint_document_impl(document) + + # Publish diagnostics for each cell. + for cell_idx, diagnostics in _group_diagnostics_by_cell(diagnostics).items(): + LSP_SERVER.publish_diagnostics( + # The cell indices are 1-based in Ruff. + params.notebook_document.cells[cell_idx - 1].document, + diagnostics, + ) + + +@LSP_SERVER.feature(NOTEBOOK_DOCUMENT_DID_CLOSE) +def did_close_notebook(params: DidCloseNotebookDocumentParams) -> None: + """LSP handler for notebookDocument/didClose request.""" + # Publishing empty diagnostics to clear the entries for all the cells in this + # Notebook Document. + for cell_text_document in params.cell_text_documents: + LSP_SERVER.publish_diagnostics(cell_text_document.uri, []) + + +@LSP_SERVER.feature(NOTEBOOK_DOCUMENT_DID_SAVE) +async def did_save_notebook(params: DidSaveNotebookDocumentParams) -> None: + """LSP handler for notebookDocument/didSave request.""" + await _did_change_or_save_notebook( + params.notebook_document.uri, run_types=[Run.OnSave, Run.OnType] + ) + + +@LSP_SERVER.feature(NOTEBOOK_DOCUMENT_DID_CHANGE) +async def did_change_notebook(params: DidChangeNotebookDocumentParams) -> None: + """LSP handler for notebookDocument/didChange request.""" + await _did_change_or_save_notebook( + params.notebook_document.uri, run_types=[Run.OnType] + ) + + +def _group_diagnostics_by_cell( + diagnostics: Iterable[Diagnostic], +) -> Mapping[int, list[Diagnostic]]: + """Group diagnostics by cell index. + + The function will return a mapping from cell number to a list of diagnostics for + that cell. The mapping will be empty if the diagnostic doesn't contain the cell + information. + """ + cell_diagnostics: dict[int, list[Diagnostic]] = {} + for diagnostic in diagnostics: + cell = cast(DiagnosticData, diagnostic.data).get("cell") + if cell is not None: + cell_diagnostics.setdefault(cell, []).append(diagnostic) + return cell_diagnostics + + +async def _did_change_or_save_notebook( + notebook_uri: str, *, run_types: Sequence[Run] +) -> None: + """Handle notebookDocument/didChange and notebookDocument/didSave requests.""" + notebook_document = LSP_SERVER.workspace.get_notebook_document( + notebook_uri=notebook_uri + ) + if notebook_document is None: + log_warning(f"No notebook document found for {notebook_uri!r}") + return + + document = Document.from_notebook_document(notebook_document) + if lint_run(_get_settings_by_document(document.path)) not in run_types: + return + + cell_diagnostics = _group_diagnostics_by_cell(await _lint_document_impl(document)) + + # Publish diagnostics for every code cell, replacing the previous diagnostics. + for cell_idx, cell in enumerate(notebook_document.cells): + if cell.kind is not NotebookCellKind.Code: + continue + LSP_SERVER.publish_diagnostics( + cell.document, + # The cell indices are 1-based in Ruff. + cell_diagnostics.get(cell_idx + 1, []), + ) + + +async def _lint_document_impl(document: Document) -> list[Diagnostic]: result = await _run_check_on_document(document) if result is None: return [] @@ -263,6 +534,7 @@ def _parse_output(content: bytes) -> list[Diagnostic]: # Ruff's output looks like: # [ # { + # "cell": null, # "code": "F841", # "message": "Local variable `x` is assigned to but never used", # "location": { @@ -302,6 +574,9 @@ def _parse_output(content: bytes) -> list[Diagnostic]: # x = 0 # print() # ``` + # + # Cell represents the cell number in a Notebook Document. It is null for normal + # Python files. for check in json.loads(content): start = Position( line=max([int(check["location"]["row"]) - 1, 0]), @@ -322,6 +597,8 @@ def _parse_output(content: bytes) -> list[Diagnostic]: fix=_parse_fix(check.get("fix")), # Available since Ruff v0.0.253. noqa_row=check.get("noqa_row"), + # Available since Ruff v0.1.0. + cell=check.get("cell"), ), tags=_get_tags(check["code"]), ) @@ -366,7 +643,11 @@ def _get_severity(code: str) -> DiagnosticSeverity: @LSP_SERVER.feature(TEXT_DOCUMENT_HOVER) async def hover(params: HoverParams) -> Hover | None: - """LSP handler for textDocument/hover request.""" + """LSP handler for textDocument/hover request. + + This works for both Python files and Notebook Documents. For Notebook Documents, + the hover works at the cell level. + """ document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) match = NOQA_REGEX.search(document.lines[params.position.line]) if not match: @@ -429,6 +710,7 @@ class Fix(TypedDict): class DiagnosticData(TypedDict, total=False): fix: Fix | None noqa_row: int | None + cell: int | None class LegacyFix(TypedDict): @@ -458,11 +740,11 @@ class LegacyFix(TypedDict): ) async def code_action(params: CodeActionParams) -> list[CodeAction] | None: """LSP handler for textDocument/codeAction request.""" - document = LSP_SERVER.workspace.get_text_document(params.text_document.uri) + document = Document.from_uri(params.text_document.uri) settings = _get_settings_by_document(document.path) - if utils.is_stdlib_file(document.path): + if document.is_stdlib_file(): # Don't format standard library files. # Publishing empty diagnostics clears the entry. return None @@ -478,14 +760,14 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None: and len(params.context.only) == 1 and kind in params.context.only ): - edits = await _fix_document_impl(document, only="I001") - if edits: + workspace_edit = await _fix_document_impl(document, only="I001") + if workspace_edit: return [ CodeAction( title="Ruff: Organize Imports", kind=kind, data=params.text_document.uri, - edit=_create_workspace_edits(document, edits), + edit=workspace_edit, diagnostics=[], ) ] @@ -503,14 +785,14 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None: and len(params.context.only) == 1 and kind in params.context.only ): - edits = await _fix_document_impl(document) - if edits: + workspace_edit = await _fix_document_impl(document) + if workspace_edit: return [ CodeAction( title="Ruff: Fix All", kind=kind, data=params.text_document.uri, - edit=_create_workspace_edits(document, edits), + edit=workspace_edit, diagnostics=[ diagnostic for diagnostic in params.context.diagnostics @@ -528,6 +810,9 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None: # Add "Ruff: Autofix" for every fixable diagnostic. if settings.get("codeAction", {}).get("fixViolation", {}).get("enable", True): if not params.context.only or CodeActionKind.QuickFix in params.context.only: + document = Document.from_text_document( + LSP_SERVER.workspace.get_text_document(params.text_document.uri) + ) for diagnostic in params.context.diagnostics: if diagnostic.source == "Ruff": fix = cast(DiagnosticData, diagnostic.data).get("fix") @@ -553,13 +838,16 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None: # Add "Disable for this line" for every diagnostic. if settings.get("codeAction", {}).get("disableRuleComment", {}).get("enable", True): if not params.context.only or CodeActionKind.QuickFix in params.context.only: + document = Document.from_text_document( + LSP_SERVER.workspace.get_text_document(params.text_document.uri) + ) lines: list[str] | None = None for diagnostic in params.context.diagnostics: if diagnostic.source == "Ruff": noqa_row = cast(DiagnosticData, diagnostic.data).get("noqa_row") if noqa_row is not None: if lines is None: - lines = document.lines + lines = document.source.splitlines(keepends=True) line = lines[noqa_row - 1].rstrip("\r\n") match = NOQA_REGEX.search(line) @@ -621,14 +909,14 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None: ), ) else: - edits = await _fix_document_impl(document, only="I001") - if edits: + workspace_edit = await _fix_document_impl(document, only="I001") + if workspace_edit: actions.append( CodeAction( title="Ruff: Organize Imports", kind=CodeActionKind.SourceOrganizeImports, data=params.text_document.uri, - edit=_create_workspace_edits(document, edits), + edit=workspace_edit, diagnostics=[], ), ) @@ -649,14 +937,14 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None: ), ) else: - edits = await _fix_document_impl(document) - if edits: + workspace_edit = await _fix_document_impl(document) + if workspace_edit: actions.append( CodeAction( title="Ruff: Fix All", kind=CodeActionKind.SourceFixAll, data=params.text_document.uri, - edit=_create_workspace_edits(document, edits), + edit=workspace_edit, diagnostics=[ diagnostic for diagnostic in params.context.diagnostics @@ -673,7 +961,8 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None: @LSP_SERVER.feature(CODE_ACTION_RESOLVE) async def resolve_code_action(params: CodeAction) -> CodeAction: """LSP handler for codeAction/resolve request.""" - document = LSP_SERVER.workspace.get_text_document(cast(str, params.data)) + # We set the `data` field to the document URI during codeAction request. + document = Document.from_uri(cast(str, params.data)) settings = _get_settings_by_document(document.path) @@ -682,51 +971,49 @@ async def resolve_code_action(params: CodeAction) -> CodeAction: f"{CodeActionKind.SourceOrganizeImports.value}.ruff", ): # Generate the "Ruff: Organize Imports" edit - results = await _fix_document_impl(document, only="I001") - params.edit = _create_workspace_edits(document, results) + params.edit = await _fix_document_impl(document, only="I001") elif settings["fixAll"] and params.kind in ( CodeActionKind.SourceFixAll, f"{CodeActionKind.SourceFixAll.value}.ruff", ): # Generate the "Ruff: Fix All" edit. - results = await _fix_document_impl(document) - params.edit = _create_workspace_edits(document, results) + params.edit = await _fix_document_impl(document) return params @LSP_SERVER.command("ruff.applyAutofix") -async def apply_autofix(arguments: tuple[TextDocument]): +async def apply_autofix(ls: server.LanguageServer, arguments: tuple[TextDocument]): uri = arguments[0]["uri"] - text_document = LSP_SERVER.workspace.get_text_document(uri) - results = await _fix_document_impl(text_document) - LSP_SERVER.apply_edit( - _create_workspace_edits(text_document, results), - "Ruff: Fix all auto-fixable problems", - ) + document = Document.from_uri(uri) + workspace_edit = await _fix_document_impl(document) + if workspace_edit is None: + return + ls.apply_edit(workspace_edit, "Ruff: Fix all auto-fixable problems") @LSP_SERVER.command("ruff.applyOrganizeImports") -async def apply_organize_imports(arguments: tuple[TextDocument]): +async def apply_organize_imports( + ls: server.LanguageServer, arguments: tuple[TextDocument] +): uri = arguments[0]["uri"] - text_document = LSP_SERVER.workspace.get_text_document(uri) - results = await _fix_document_impl(text_document, only="I001") - LSP_SERVER.apply_edit( - _create_workspace_edits(text_document, results), - "Ruff: Format imports", - ) + document = Document.from_uri(uri) + workspace_edit = await _fix_document_impl(document, only="I001") + if workspace_edit is None: + return + ls.apply_edit(workspace_edit, "Ruff: Format imports") @LSP_SERVER.command("ruff.applyFormat") async def apply_format(arguments: tuple[TextDocument]): uri = arguments[0]["uri"] - text_document = LSP_SERVER.workspace.get_text_document(uri) - results = await _format_document_impl(text_document) - LSP_SERVER.apply_edit( - _create_workspace_edits(text_document, results), - "Ruff: Format document", - ) + document = Document.from_uri(uri) + results = await _run_format_on_document(document) + workspace_edit = _result_to_workspace_edit(document, results) + if workspace_edit is None: + return + LSP_SERVER.apply_edit(workspace_edit, "Ruff: Format document") if RUFF_EXPERIMENTAL_FORMATTER: @@ -736,47 +1023,116 @@ async def format_document( ls: server.LanguageServer, params: DocumentFormattingParams, ) -> list[TextEdit] | None: - uri = params.text_document.uri - document = ls.workspace.get_text_document(uri) - return await _format_document_impl(document) - - -async def _format_document_impl( - document: workspace.TextDocument, -) -> list[TextEdit]: - result = await _run_format_on_document(document) - return _result_to_edits(document, result) + # For a Jupyter Notebook, this request can only format a single cell as the + # request itself can only act on a text document. A cell in a Notebook is + # represented as a text document. + document = Document.from_text_document( + ls.workspace.get_text_document(params.text_document.uri) + ) + result = await _run_format_on_document(document) + if result is None: + return None + return _fixed_source_to_edits( + original_source=document.source, + fixed_source=result.stdout.decode("utf-8"), + is_notebook_file=document.is_notebook_file(), + ) async def _fix_document_impl( - document: workspace.TextDocument, + document: Document, *, only: str | None = None, -) -> list[TextEdit]: - result = await _run_check_on_document(document, extra_args=["--fix"], only=only) - return _result_to_edits(document, result) +) -> WorkspaceEdit | None: + result = await _run_check_on_document( + document, + extra_args=["--fix"], + only=only, + ) + return _result_to_workspace_edit(document, result) -def _result_to_edits( - document: workspace.TextDocument, - result: RunResult | None, -) -> list[TextEdit]: +def _result_to_workspace_edit( + document: Document, result: RunResult | None +) -> WorkspaceEdit | None: + """Converts a run result to a WorkspaceEdit.""" if result is None: - return [] + return None + + if document.kind is DocumentKind.Text: + edits = _fixed_source_to_edits( + original_source=document.source, + fixed_source=result.stdout.decode("utf-8"), + is_notebook_file=document.is_notebook_file(), + ) + return WorkspaceEdit( + document_changes=[ + _create_text_document_edit(document.uri, document.version, edits) + ] + ) + elif document.kind is DocumentKind.Notebook: + notebook_document = LSP_SERVER.workspace.get_notebook_document( + notebook_uri=document.uri + ) + if notebook_document is None: + log_warning(f"No notebook document found for {document.uri!r}") + return None + + output_notebook_cells = cast( + Notebook, json.loads(result.stdout.decode("utf-8")) + )["cells"] + if len(output_notebook_cells) != len(notebook_document.cells): + log_warning( + f"Number of cells in the output notebook doesn't match the number of " + f"cells in the input notebook. Input: {len(notebook_document.cells)}, " + f"Output: {len(output_notebook_cells)}" + ) + return None + + cell_document_changes: list[TextDocumentEdit] = [] + for cell_idx, cell in enumerate(notebook_document.cells): + if cell.kind is not NotebookCellKind.Code: + continue + cell_document = LSP_SERVER.workspace.get_text_document(cell.document) + edits = _fixed_source_to_edits( + original_source=cell_document.source, + fixed_source=output_notebook_cells[cell_idx]["source"], + is_notebook_file=True, + ) + cell_document_changes.append( + _create_text_document_edit( + cell_document.uri, + cell_document.version, + edits, + ) + ) + + return WorkspaceEdit(document_changes=list(cell_document_changes)) + else: + assert_never(document.kind) - if not result.stdout: - return [] - new_source = _match_line_endings(document, result.stdout.decode("utf-8")) +def _fixed_source_to_edits( + *, original_source: str, fixed_source: str | list[str], is_notebook_file: bool +) -> list[TextEdit]: + """Converts the fixed source to a list of TextEdits. + + If the fixed source is a list of strings, it is joined together to form a single + string with an assumption that the line endings are part of the strings itself. + """ + if isinstance(fixed_source, list): + fixed_source = "".join(fixed_source) + + new_source = _match_line_endings(original_source, fixed_source) # Skip last line ending in a notebook cell. - if document.uri.startswith("vscode-notebook-cell"): + if is_notebook_file: if new_source.endswith("\r\n"): new_source = new_source[:-2] elif new_source.endswith("\n"): new_source = new_source[:-1] - if new_source == document.source: + if new_source == original_source: return [] return [ @@ -790,30 +1146,25 @@ def _result_to_edits( ] -def _create_workspace_edits( - document: workspace.TextDocument, - edits: Sequence[TextEdit | AnnotatedTextEdit], -) -> WorkspaceEdit: - return WorkspaceEdit( - document_changes=[ - TextDocumentEdit( - text_document=OptionalVersionedTextDocumentIdentifier( - uri=document.uri, - version=0 if document.version is None else document.version, - ), - edits=list(edits), - ) - ], +def _create_text_document_edit( + uri: str, version: int | None, edits: Sequence[TextEdit | AnnotatedTextEdit] +) -> TextDocumentEdit: + return TextDocumentEdit( + text_document=OptionalVersionedTextDocumentIdentifier( + uri=uri, + version=0 if version is None else version, + ), + edits=list(edits), ) -def _create_workspace_edit(document: workspace.TextDocument, fix: Fix) -> WorkspaceEdit: +def _create_workspace_edit(document: Document, fix: Fix) -> WorkspaceEdit: return WorkspaceEdit( document_changes=[ TextDocumentEdit( text_document=OptionalVersionedTextDocumentIdentifier( uri=document.uri, - version=0 if document.version is None else document.version, + version=document.version, ), edits=[ TextEdit( @@ -849,13 +1200,13 @@ def _get_line_endings(text: str) -> str | None: return None # No line ending found -def _match_line_endings(document: workspace.TextDocument, text: str) -> str: +def _match_line_endings(original_source: str, fixed_source: str) -> str: """Ensures that the edited text line endings matches the document line endings.""" - expected = _get_line_endings(document.source) - actual = _get_line_endings(text) + expected = _get_line_endings(original_source) + actual = _get_line_endings(fixed_source) if actual is None or expected is None or actual == expected: - return text - return text.replace(actual, expected) + return fixed_source + return fixed_source.replace(actual, expected) async def run_path( @@ -1148,17 +1499,13 @@ def _executable_version(executable: str) -> Version: async def _run_check_on_document( - document: workspace.TextDocument, + document: Document, *, extra_args: Sequence[str] = [], only: str | None = None, ) -> RunResult | None: """Runs the Ruff `check` subcommand on the given document.""" - if str(document.uri).startswith("vscode-notebook-cell"): - # Skip notebook cells - return None - - if utils.is_stdlib_file(document.path): + if document.is_stdlib_file(): log_warning(f"Skipping standard library file: {document.path}") return None @@ -1202,13 +1549,9 @@ async def _run_check_on_document( ) -async def _run_format_on_document(document: workspace.TextDocument) -> RunResult | None: +async def _run_format_on_document(document: Document) -> RunResult | None: """Runs the Ruff `format` subcommand on the given document.""" - if str(document.uri).startswith("vscode-notebook-cell"): - # Skip notebook cells - return None - - if utils.is_stdlib_file(document.path): + if document.is_stdlib_file(): log_warning(f"Skipping standard library file: {document.path}") return None diff --git a/ruff_lsp/settings.py b/ruff_lsp/settings.py index 164ba6a..5d99ade 100644 --- a/ruff_lsp/settings.py +++ b/ruff_lsp/settings.py @@ -1,8 +1,19 @@ from __future__ import annotations +import enum + from typing_extensions import Literal, TypedDict -Run = Literal["onSave", "onType"] + +@enum.unique +class Run(str, enum.Enum): + """When to run Ruff.""" + + OnType = "onType" + """Run Ruff on every keystroke.""" + + OnSave = "onSave" + """Run Ruff on save.""" class UserSettings(TypedDict, total=False): @@ -94,8 +105,8 @@ def lint_args(settings: UserSettings) -> list[str]: def lint_run(settings: UserSettings) -> Run: """Get the `lint.run` setting from the user settings.""" if "lint" in settings and "run" in settings["lint"]: - return settings["lint"]["run"] + return Run(settings["lint"]["run"]) elif "run" in settings: - return settings["run"] + return Run(settings["run"]) else: - return "onType" + return Run.OnType diff --git a/tests/test_format.py b/tests/test_format.py index acfa52c..6285406 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -8,7 +8,9 @@ from ruff_lsp.server import ( VERSION_REQUIREMENT_FORMATTER, - _format_document_impl, + Document, + _fixed_source_to_edits, + _run_format_on_document, ) from tests.client import utils @@ -27,7 +29,7 @@ async def test_format(tmp_path, ruff_version: Version): uri = utils.as_uri(str(test_file)) workspace = Workspace(str(tmp_path)) - document = workspace.get_text_document(uri) + document = Document.from_text_document(workspace.get_text_document(uri)) handle_unsupported = ( pytest.raises(RuntimeError, match=f"Ruff .* required, but found {ruff_version}") @@ -36,6 +38,11 @@ async def test_format(tmp_path, ruff_version: Version): ) with handle_unsupported: - result = await _format_document_impl(document) - [edit] = result + result = await _run_format_on_document(document) + assert result is not None + [edit] = _fixed_source_to_edits( + original_source=document.source, + fixed_source=result.stdout.decode("utf-8"), + is_notebook_file=document.is_notebook_file(), + ) assert edit.new_text == expected diff --git a/tests/test_server.py b/tests/test_server.py index beb5e5f..2b67c4e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -76,6 +76,7 @@ def _handler(params): "message": "Remove unused import: `sys`", }, "noqa_row": 1, + "cell": None, }, "message": "`sys` imported but unused", "range": { @@ -91,7 +92,7 @@ def _handler(params): "codeDescription": { "href": "https://docs.astral.sh/ruff/rules/undefined-name" }, - "data": {"fix": None, "noqa_row": 3}, + "data": {"fix": None, "noqa_row": 3, "cell": None}, "message": "Undefined name `x`", "range": { "end": {"character": 7, "line": 2}, @@ -165,6 +166,7 @@ def _handler(params): "message": "Remove unused import: `sys`", }, "noqa_row": 1, + "cell": None, }, "message": "`sys` imported but unused", "range": { @@ -180,7 +182,7 @@ def _handler(params): "codeDescription": { "href": "https://docs.astral.sh/ruff/rules/undefined-name" }, - "data": {"fix": None, "noqa_row": 3}, + "data": {"fix": None, "noqa_row": 3, "cell": None}, "message": "Undefined name `x`", "range": { "end": {"character": 7, "line": 2},