diff --git a/doc/conf.py b/doc/conf.py index 3e210a3fa..a01b6b85d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -39,6 +39,7 @@ "sphinxext_altair.altairplot", "sphinxext.altairgallery", "sphinxext.schematable", + "sphinxext.code_ref", "sphinx_copybutton", "sphinx_design", ] diff --git a/doc/user_guide/customization.rst b/doc/user_guide/customization.rst index c068d62e3..ec8df035e 100644 --- a/doc/user_guide/customization.rst +++ b/doc/user_guide/customization.rst @@ -787,10 +787,16 @@ If you would like to use any theme just for a single chart, you can use the with alt.themes.enable('default'): spec = chart.to_json() +Built-in Themes +~~~~~~~~~~~~~~~ Currently Altair does not offer many built-in themes, but we plan to add more options in the future. -See `Vega Theme Test`_ for an interactive demo of themes inherited from `Vega Themes`_. +You can get a feel for the themes inherited from `Vega Themes`_ via *Vega-Altair Theme Test* below: + +.. altair-theme:: tests.altair_theme_test.alt_theme_test + :fold: + :summary: Show Vega-Altair Theme Test Defining a Custom Theme ~~~~~~~~~~~~~~~~~~~~~~~ @@ -843,6 +849,13 @@ If you want to restore the default theme, use: alt.themes.enable('default') +When experimenting with your theme, you can use the code below to see how +it translates across a range of charts/marks: + +.. altair-code-ref:: tests.altair_theme_test.alt_theme_test + :fold: + :summary: Show Vega-Altair Theme Test code + For more ideas on themes, see the `Vega Themes`_ repository. @@ -889,5 +902,4 @@ The configured localization settings persist upon saving. alt.renderers.set_embed_options(format_locale="en-US", time_format_locale="en-US") .. _Vega Themes: https://github.com/vega/vega-themes/ -.. _`D3's localization support`: https://d3-wiki.readthedocs.io/zh-cn/master/Localization/ -.. _Vega Theme Test: https://vega.github.io/vega-themes/?renderer=canvas \ No newline at end of file +.. _`D3's localization support`: https://d3-wiki.readthedocs.io/zh-cn/master/Localization/ \ No newline at end of file diff --git a/sphinxext/code_ref.py b/sphinxext/code_ref.py new file mode 100644 index 000000000..6371713de --- /dev/null +++ b/sphinxext/code_ref.py @@ -0,0 +1,330 @@ +"""Sphinx extension providing formatted code blocks, referencing some function.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, get_args + +from docutils import nodes +from docutils.parsers.rst import directives +from sphinx.util.docutils import SphinxDirective +from sphinx.util.parsing import nested_parse_to_nodes + +from altair.vegalite.v5.schema._typing import VegaThemes +from tools.codemod import extract_func_def, extract_func_def_embed + +if TYPE_CHECKING: + import sys + from typing import ( + Any, + Callable, + ClassVar, + Iterable, + Iterator, + Mapping, + Sequence, + TypeVar, + Union, + ) + + from docutils.parsers.rst.states import RSTState, RSTStateMachine + from docutils.statemachine import StringList + from sphinx.application import Sphinx + + if sys.version_info >= (3, 12): + from typing import TypeAliasType + else: + from typing_extensions import TypeAliasType + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + T = TypeVar("T") + OneOrIter = TypeAliasType("OneOrIter", Union[T, Iterable[T]], type_params=(T,)) + +_OutputShort: TypeAlias = Literal["code", "plot"] +_OutputLong: TypeAlias = Literal["code-block", "altair-plot"] +_OUTPUT_REMAP: Mapping[_OutputShort, _OutputLong] = { + "code": "code-block", + "plot": "altair-plot", +} +_Option: TypeAlias = Literal["output", "fold", "summary"] + +_PYSCRIPT_URL_FMT = "https://pyscript.net/releases/{0}/core.js" +_PYSCRIPT_VERSION = "2024.10.1" +_PYSCRIPT_URL = _PYSCRIPT_URL_FMT.format(_PYSCRIPT_VERSION) + + +def validate_output(output: Any) -> _OutputLong: + output = output.strip().lower() + if output not in {"plot", "code"}: + msg = f":output: option must be one of {get_args(_OutputShort)!r}" + raise TypeError(msg) + else: + short: _OutputShort = output + return _OUTPUT_REMAP[short] + + +def validate_packages(packages: Any) -> str: + if packages is None: + return '["altair"]' + else: + split = [pkg.strip() for pkg in packages.split(",")] + if len(split) == 1: + return f'["{split[0]}"]' + else: + return f'[{",".join(split)}]' + + +def raw_html(text: str, /) -> nodes.raw: + return nodes.raw("", text, format="html") + + +def maybe_details( + parsed: Iterable[nodes.Node], options: dict[_Option, Any], *, default_summary: str +) -> Sequence[nodes.Node]: + """ + Wrap ``parsed`` in a folding `details`_ block if requested. + + Parameters + ---------- + parsed + Target nodes that have been processed. + options + Optional arguments provided to ``.. altair-code-ref::``. + + .. note:: + If no relevant options are specified, + ``parsed`` is returned unchanged. + + default_summary + Label text used when **only** specifying ``:fold:``. + + .. _details: + https://developer.mozilla.org/en-US/docs/Web/HTML/Element/details + """ + + def gen() -> Iterator[nodes.Node]: + if {"fold", "summary"}.isdisjoint(options.keys()): + yield from parsed + else: + summary = options.get("summary", default_summary) + yield raw_html(f"

{summary}") + yield from parsed + yield raw_html("

") + + return list(gen()) + + +def theme_names() -> tuple[Sequence[str], Sequence[str]]: + names: set[VegaThemes] = set(get_args(VegaThemes)) + carbon = {nm for nm in names if nm.startswith("carbon")} + return ["default", *sorted(names - carbon)], sorted(carbon) + + +def option(label: str, value: str | None = None, /) -> nodes.raw: + s = f"\n") + + +def optgroup(label: str, *options: OneOrIter[nodes.raw]) -> Iterator[nodes.raw]: + yield raw_html(f"\n") + for opt in options: + if isinstance(opt, nodes.raw): + yield opt + else: + yield from opt + yield raw_html("\n") + + +def dropdown( + id: str, label: str | None, extra_select: str, *options: OneOrIter[nodes.raw] +) -> Iterator[nodes.raw]: + if label: + yield raw_html(f"\n") + select_text = f"\n") + + +def pyscript( + packages: str, target_div_id: str, loading_label: str, py_code: str +) -> Iterator[nodes.raw]: + PY = "py" + LB, RB = "{", "}" + packages = f""""packages":{packages}""" + yield raw_html(f"
{loading_label}
\n") + yield raw_html(f"\n") + + +def _before_code(refresh_name: str, select_id: str, target_div_id: str) -> str: + INDENT = " " * 4 + return ( + f"from js import document\n" + f"from pyscript import display\n" + f"import altair as alt\n\n" + f"def {refresh_name}(*args):\n" + f"{INDENT}selected = document.getElementById({select_id!r}).value\n" + f"{INDENT}alt.renderers.set_embed_options(theme=selected)\n" + f"{INDENT}display(chart, append=False, target={target_div_id!r})\n" + ) + + +class ThemeDirective(SphinxDirective): + """ + Theme preview directive. + + Similar to ``CodeRefDirective``, but uses `PyScript`_ to access the browser. + + .. _PyScript: + https://pyscript.net/ + """ + + has_content: ClassVar[Literal[False]] = False + required_arguments: ClassVar[Literal[1]] = 1 + option_spec = { + "packages": validate_packages, + "dropdown-label": directives.unchanged, + "loading-label": directives.unchanged, + "fold": directives.flag, + "summary": directives.unchanged_required, + } + + def run(self) -> Sequence[nodes.Node]: + results: list[nodes.Node] = [] + SELECT_ID = "embed_theme" + REFRESH_NAME = "apply_embed_input" + TARGET_DIV_ID = "render_altair" + standard_names, carbon_names = theme_names() + + qual_name = self.arguments[0] + module_name, func_name = qual_name.rsplit(".", 1) + dropdown_label = self.options.get("dropdown-label", "Select theme:") + loading_label = self.options.get("loading-label", "loading...") + packages: str = self.options.get("packages", validate_packages(None)) + + results.append(raw_html("

\n")) + results.extend( + dropdown( + SELECT_ID, + dropdown_label, + f"py-input={REFRESH_NAME!r}", + (option(nm) for nm in standard_names), + optgroup("Carbon", (option(nm) for nm in carbon_names)), + ) + ) + py_code = extract_func_def_embed( + module_name, + func_name, + before=_before_code(REFRESH_NAME, SELECT_ID, TARGET_DIV_ID), + after=f"{REFRESH_NAME}()", + assign_to="chart", + indent=4, + ) + results.extend( + pyscript(packages, TARGET_DIV_ID, loading_label, py_code=py_code) + ) + results.append(raw_html("

\n")) + return maybe_details( + results, self.options, default_summary="Show Vega-Altair Theme Test" + ) + + +class PyScriptDirective(SphinxDirective): + """Placeholder for non-theme related directive.""" + + has_content: ClassVar[Literal[False]] = False + option_spec = {"packages": directives.unchanged} + + def run(self) -> Sequence[nodes.Node]: + raise NotImplementedError + + +class CodeRefDirective(SphinxDirective): + """ + Formatted code block, referencing the contents of a function definition. + + Options: + + .. altair-code-ref:: + :output: [code, plot] + :fold: flag + :summary: str + + Examples + -------- + Reference a function, generating a code block: + + .. altair-code-ref:: package.module.function + + Wrap the code block in a collapsible `details`_ tag: + + .. altair-code-ref:: package.module.function + :fold: + + Override default ``"Show code"`` `details`_ summary: + + .. altair-code-ref:: package.module.function + :fold: + :summary: Look here! + + Use `altair-plot`_ instead of a code block: + + .. altair-code-ref:: package.module.function + :output: plot + + .. note:: + Using `altair-plot`_ currently ignores the other options. + + .. _details: + https://developer.mozilla.org/en-US/docs/Web/HTML/Element/details + .. _altair-plot: + https://github.com/vega/sphinxext-altair + """ + + has_content: ClassVar[Literal[False]] = False + required_arguments: ClassVar[Literal[1]] = 1 + option_spec: ClassVar[dict[_Option, Callable[[str], Any]]] = { + "output": validate_output, + "fold": directives.flag, + "summary": directives.unchanged_required, + } + + def __init__( + self, + name: str, + arguments: list[str], + options: dict[_Option, Any], + content: StringList, + lineno: int, + content_offset: int, + block_text: str, + state: RSTState, + state_machine: RSTStateMachine, + ) -> None: + super().__init__(name, arguments, options, content, lineno, content_offset, block_text, state, state_machine) # fmt: skip + self.options: dict[_Option, Any] + + def run(self) -> Sequence[nodes.Node]: + qual_name = self.arguments[0] + module_name, func_name = qual_name.rsplit(".", 1) + output: _OutputLong = self.options.get("output", "code-block") + content = extract_func_def(module_name, func_name, output=output) + parsed = nested_parse_to_nodes(self.state, content) + return maybe_details(parsed, self.options, default_summary="Show code") + + +def setup(app: Sphinx) -> None: + app.add_directive_to_domain("py", "altair-code-ref", CodeRefDirective) + app.add_js_file(_PYSCRIPT_URL, loading_method="defer", type="module") + # app.add_directive("altair-pyscript", PyScriptDirective) + app.add_directive("altair-theme", ThemeDirective) diff --git a/tests/altair_theme_test.py b/tests/altair_theme_test.py new file mode 100644 index 000000000..b9114baec --- /dev/null +++ b/tests/altair_theme_test.py @@ -0,0 +1,141 @@ +# ruff: noqa: E711 +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from altair.typing import ChartType + + +def alt_theme_test() -> ChartType: + import altair as alt + + VEGA_DATASETS = "https://cdn.jsdelivr.net/npm/vega-datasets@v1.29.0/data/" + us_10m = f"{VEGA_DATASETS}us-10m.json" + unemployment = f"{VEGA_DATASETS}unemployment.tsv" + movies = f"{VEGA_DATASETS}movies.json" + barley = f"{VEGA_DATASETS}barley.json" + iowa_electricity = f"{VEGA_DATASETS}iowa-electricity.csv" + common_data = alt.InlineData( + [ + {"Index": 1, "Value": 28, "Position": 1, "Category": "A"}, + {"Index": 2, "Value": 55, "Position": 2, "Category": "A"}, + {"Index": 3, "Value": 43, "Position": 3, "Category": "A"}, + {"Index": 4, "Value": 91, "Position": 4, "Category": "A"}, + {"Index": 5, "Value": 81, "Position": 5, "Category": "A"}, + {"Index": 6, "Value": 53, "Position": 6, "Category": "A"}, + {"Index": 7, "Value": 19, "Position": 1, "Category": "B"}, + {"Index": 8, "Value": 87, "Position": 2, "Category": "B"}, + {"Index": 9, "Value": 52, "Position": 3, "Category": "B"}, + {"Index": 10, "Value": 48, "Position": 4, "Category": "B"}, + {"Index": 11, "Value": 24, "Position": 5, "Category": "B"}, + {"Index": 12, "Value": 49, "Position": 6, "Category": "B"}, + {"Index": 13, "Value": 87, "Position": 1, "Category": "C"}, + {"Index": 14, "Value": 66, "Position": 2, "Category": "C"}, + {"Index": 15, "Value": 17, "Position": 3, "Category": "C"}, + {"Index": 16, "Value": 27, "Position": 4, "Category": "C"}, + {"Index": 17, "Value": 68, "Position": 5, "Category": "C"}, + {"Index": 18, "Value": 16, "Position": 6, "Category": "C"}, + ] + ) + + HEIGHT_SMALL = 140 + STANDARD = 180 + WIDTH_GEO = int(STANDARD * 1.667) + + bar = ( + alt.Chart(common_data, height=HEIGHT_SMALL, width=STANDARD, title="Bar") + .mark_bar() + .encode( + x=alt.X("Index:O").axis(offset=1), y=alt.Y("Value:Q"), tooltip="Value:Q" + ) + .transform_filter(alt.datum["Index"] <= 9) + ) + line = ( + alt.Chart(common_data, height=HEIGHT_SMALL, width=STANDARD, title="Line") + .mark_line() + .encode( + x=alt.X("Position:O").axis(grid=False), + y=alt.Y("Value:Q").axis(grid=False), + color=alt.Color("Category:N").legend(None), + tooltip=["Index:O", "Value:Q", "Position:O", "Category:N"], + ) + ) + point_shape = ( + alt.Chart( + common_data, height=HEIGHT_SMALL, width=STANDARD, title="Point (Shape)" + ) + .mark_point() + .encode( + x=alt.X("Position:O").axis(grid=False), + y=alt.Y("Value:Q").axis(grid=False), + shape=alt.Shape("Category:N").legend(None), + color=alt.Color("Category:N").legend(None), + tooltip=["Index:O", "Value:Q", "Position:O", "Category:N"], + ) + ) + point = ( + alt.Chart(movies, height=STANDARD, width=STANDARD, title="Point") + .mark_point(tooltip=True) + .transform_filter(alt.datum["IMDB_Rating"] != None) + .transform_filter( + alt.FieldRangePredicate("Release_Date", [None, 2019], timeUnit="year") + ) + .transform_joinaggregate(Average_Rating="mean(IMDB_Rating)") + .transform_calculate( + Rating_Delta=alt.datum["IMDB_Rating"] - alt.datum.Average_Rating + ) + .encode( + x=alt.X("Release_Date:T").title("Release Date"), + y=alt.Y("Rating_Delta:Q").title("Rating Delta"), + color=alt.Color("Rating_Delta:Q").title("Rating Delta").scale(domainMid=0), + ) + ) + bar_stack = ( + alt.Chart(barley, height=STANDARD, width=STANDARD, title="Bar (Stacked)") + .mark_bar(tooltip=True) + .encode( + x="sum(yield):Q", + y=alt.Y("variety:N"), + color=alt.Color("site:N").legend(orient="bottom", columns=2), + ) + ) + area = ( + alt.Chart(iowa_electricity, height=STANDARD, width=STANDARD, title="Area") + .mark_area(tooltip=True) + .encode( + x=alt.X("year:T").title("Year"), + y=alt.Y("net_generation:Q") + .title("Share of net generation") + .stack("normalize") + .axis(format=".0%"), + color=alt.Color("source:N") + .title("Electricity source") + .legend(orient="bottom", columns=2), + ) + ) + geoshape = ( + alt.Chart( + alt.topo_feature(us_10m, "counties"), + height=STANDARD, + width=WIDTH_GEO, + title=alt.Title("Geoshape", subtitle="Unemployment rate per county"), + ) + .mark_geoshape(tooltip=True) + .encode(color="rate:Q") + .transform_lookup( + "id", alt.LookupData(alt.UrlData(unemployment), "id", ["rate"]) + ) + .project(type="albersUsa") + ) + + compound_chart = ( + (bar | line | point_shape) & (point | bar_stack) & (area | geoshape) + ).properties( + title=alt.Title( + "Vega-Altair Theme Test", + fontSize=20, + subtitle="Adapted from https://vega.github.io/vega-themes/", + ) + ) + return compound_chart diff --git a/tools/codemod.py b/tools/codemod.py new file mode 100644 index 000000000..8281e0732 --- /dev/null +++ b/tools/codemod.py @@ -0,0 +1,453 @@ +# ruff: noqa: D418 +from __future__ import annotations + +import ast +import subprocess +import sys +import textwrap +import warnings +from collections import deque +from importlib.util import find_spec +from pathlib import Path +from typing import TYPE_CHECKING, Any, Iterable, TypeVar, overload + +if sys.version_info >= (3, 12): + from typing import Protocol, TypeAliasType +else: + from typing_extensions import Protocol, TypeAliasType + +if TYPE_CHECKING: + if sys.version_info >= (3, 11): + from typing import LiteralString + else: + from typing_extensions import LiteralString + + from typing import ClassVar, Iterator, Literal + + +__all__ = ["extract_func_def", "extract_func_def_embed", "ruff", "ruff_inline_docs"] + +T = TypeVar("T") +OneOrIterV = TypeAliasType( + "OneOrIterV", + "T | Iterable[T] | Iterable[OneOrIterV[T]]", + type_params=(T,), +) +_Code = OneOrIterV[str] + + +def parse_module(name: str, /) -> ast.Module: + """ + Find absolute path and parse module into an ast. + + Use regular dotted import style, no `.py` suffix. + + Acceptable ``name``: + + altair.package.subpackage.etc + tools.____ + tests.____ + doc.____ + sphinxext.____ + """ + if (spec := find_spec(name)) and (origin := spec.origin): + return ast.parse(Path(origin).read_bytes()) + else: + raise FileNotFoundError(name) + + +if sys.version_info >= (3, 9): + + def unparse(obj: ast.AST, /) -> str: + return ast.unparse(obj) +else: + + def unparse(obj: ast.AST, /) -> str: + """ + Added in ``3.9``. + + https://docs.python.org/3/library/ast.html#ast.unparse + """ + # HACK: Will only be used during build/docs + # - This branch is just to satisfy linters + msg = f"Called `ast.unparse()` on {sys.version_info!r}\nFunction not available before {(3, 9)!r}" + warnings.warn(msg, ImportWarning, stacklevel=2) + return "" + + +def find_func_def(mod: ast.Module, fn_name: str, /) -> ast.FunctionDef: + """ + Return a function node matching ``fn_name``. + + Notes + ----- + Provides some extra type safety, over:: + + ast.Module.body: list[ast.stmt] + """ + for stmt in mod.body: + if isinstance(stmt, ast.FunctionDef) and stmt.name == fn_name: + return stmt + else: + continue + msg = f"Found no function named {fn_name!r}" + raise NotImplementedError(msg) + + +def validate_body(fn: ast.FunctionDef, /) -> tuple[list[ast.stmt], ast.expr]: + """ + Ensure function has inlined imports and a return statement. + + Returns:: + + (ast.FunctionDef.body[:-1], ast.Return.value) + """ + body = fn.body + first = body[0] + if not isinstance(first, (ast.Import, ast.ImportFrom)): + msg = ( + f"First statement in function must be an import, " + f"got {type(first).__name__!r}\n\n" + f"{unparse(first)!r}" + ) + raise TypeError(msg) + last = body.pop() + if not isinstance(last, ast.Return) or last.value is None: + body.append(last) + msg = ( + f"Last statement in function must return an expression, " + f"got {type(last).__name__!r}\n\n" + f"{unparse(last)!r}" + ) + raise TypeError(msg) + else: + return body, last.value + + +def iter_flatten(*elements: _Code) -> Iterator[str]: + for el in elements: + if not isinstance(el, str) and isinstance(el, Iterable): + yield from iter_flatten(*el) + elif isinstance(el, str): + yield el + else: + msg = ( + f"Expected all elements to eventually flatten to ``str``, " + f"but got: {type(el).__name__!r}\n\n" + f"{el!r}" + ) + raise TypeError(msg) + + +def iter_func_def_unparse( + module_name: str, + func_name: str, + /, + *, + return_transform: Literal["assign"] | None = None, + assign_to: str = "chart", +) -> Iterator[str]: + # Planning to add pyscript code before/after + # Then add `ruff check` to clean up duplicate imports (on the whole thing) + # Optional: assign the return to `assign_to` + # - Allows writing modular code that doesn't depend on the variable names in the original function + mod = parse_module(module_name) + fn = find_func_def(mod, func_name) + body, ret = validate_body(fn) + for node in body: + yield unparse(node) + yield "" + ret_value = unparse(ret) + if return_transform is None: + yield ret_value + elif return_transform == "assign": + yield f"{assign_to} = {ret_value}" + else: + msg = f"{return_transform=}" + raise NotImplementedError(msg) + + +def extract_func_def( + module_name: str, + func_name: str, + *, + format: bool = True, + output: Literal["altair-plot", "code-block", "str"] = "str", +) -> str: + """ + Extract the contents of a function for use as a code block. + + Parameters + ---------- + module_name + Absolute, dotted import style. + func_name + Name of function in ``module_name``. + format + Run through ``ruff format`` before returning. + output + Optionally, return embedded in an `rst` directive. + + Notes + ----- + - Functions must define all imports inline, to ensure they are propagated + - Must end with a single return statement + + Warning + ------- + Requires ``python>=3.9`` for `ast.unparse`_ + + Examples + -------- + Transform the contents of a function into a code block:: + + >>> extract_func_def("tests.altair_theme_test", "alt_theme_test", output="code-block") # doctest: +SKIP + + .. _ast.unparse: + https://docs.python.org/3.9/library/ast.html#ast.unparse + """ + if output not in {"altair-plot", "code-block", "str"}: + raise TypeError(output) + it = iter_func_def_unparse(module_name, func_name) + s = ruff_inline_docs.format(it) if format else "\n".join(it) + if output == "str": + return s + else: + return f".. {output}::\n\n{textwrap.indent(s, ' ' * 4)}\n" + + +def extract_func_def_embed( + module_name: str, + func_name: str, + /, + before: _Code | None = None, + after: _Code | None = None, + assign_to: str = "chart", + indent: int | None = None, +) -> str: + """ + Extract the contents of a function, wrapping with ``before`` and ``after``. + + The resulting code block is run through ``ruff`` to deduplicate imports + and apply consistent formatting. + + Parameters + ---------- + module_name + Absolute, dotted import style. + func_name + Name of function in ``module_name``. + before + Code inserted before ``func_name``. + after + Code inserted after ``func_name``. + assign_to + Variable name to use as the result of ``func_name``. + + .. note:: + Allows the ``after`` block to use a consistent reference. + indent + Optionally, prefix ``indent * " "`` to final block. + + .. note:: + Occurs **after** formatting, will not contribute to line length wrap. + """ + if before is None and after is None: + msg = ( + f"At least one additional code fragment should be provided, but:\n" + f"{before=}, {after=}\n\n" + f"Use {extract_func_def.__qualname__!r} instead." + ) + warnings.warn(msg, UserWarning, stacklevel=2) + unparsed = iter_func_def_unparse( + module_name, func_name, return_transform="assign", assign_to=assign_to + ) + parts = [p for p in (before, unparsed, after) if p is not None] + formatted = ruff_inline_docs(parts) + return textwrap.indent(formatted, " " * indent) if indent else formatted + + +class CodeMod(Protocol): + def __call__(self, *code: _Code) -> str: + """ + Transform some input into a single block of modified code. + + Parameters + ---------- + *code + Arbitrarily nested code fragments. + """ + ... + + def _join(self, code: _Code, *, sep: str = "\n") -> str: + """ + Concatenate any number of code fragments. + + All nested groups are unwrapped into a flat iterable. + """ + return sep.join(iter_flatten(code)) + + +class Ruff(CodeMod): + """ + Run `ruff`_ commands against code fragments or files. + + By default, uses the same config as `pyproject.toml`_. + + Parameters + ---------- + *extend_select + `rule codes`_ to use **on top of** the default config. + ignore + `rule codes`_ to `ignore`_. + skip_magic_traling_comma + Enables `skip-magic-trailing-comma`_ during formatting. + + .. note:: + + Only use on code that is changing indent-level + (e.g. unwrapping function contents). + + .. _ruff: + https://docs.astral.sh/ruff/ + .. _pyproject.toml: + https://github.com/vega/altair/blob/main/pyproject.toml + .. _rule codes: + https://docs.astral.sh/ruff/rules/ + .. _ignore: + https://docs.astral.sh/ruff/settings/#lint_ignore + .. _skip-magic-trailing-comma: + https://docs.astral.sh/ruff/settings/#format_skip-magic-trailing-comma + """ + + _stdin_args: ClassVar[tuple[LiteralString, ...]] = ( + "--stdin-filename", + "placeholder.py", + ) + _check_args: ClassVar[tuple[LiteralString, ...]] = ("--fix",) + + def __init__( + self, + *extend_select: str, + ignore: OneOrIterV[str] | None = None, + skip_magic_traling_comma: bool = False, + ) -> None: + self.check_args: deque[str] = deque(self._check_args) + self.format_args: deque[str] = deque() + for c in extend_select: + self.check_args.extend(("--extend-select", c)) + if ignore is not None: + self.check_args.extend( + ("--ignore", ",".join(s for s in iter_flatten(ignore))) + ) + if skip_magic_traling_comma: + self.format_args.extend( + ("--config", "format.skip-magic-trailing-comma = true") + ) + + def write_lint_format(self, fp: Path, code: _Code, /) -> None: + """ + Combined steps of writing, `ruff check`, `ruff format`. + + Parameters + ---------- + fp + Target file to write to + code + Some (potentially) nested code fragments. + + Notes + ----- + - `fp` is written to first, as the size before formatting will be the smallest + - Better utilizes `ruff` performance, rather than `python` str and io + """ + self.check(fp, code) + self.format(fp) + + @overload + def check(self, *code: _Code, decode: Literal[True] = ...) -> str: + """Fixes violations and returns fixed code.""" + + @overload + def check(self, *code: _Code, decode: Literal[False]) -> bytes: + """ + ``decode=False`` will return as ``bytes``. + + Helpful if piping to another command. + """ + + @overload + def check(self, _write_to: Path, /, *code: _Code) -> None: + """ + ``code`` is joined, written to provided path and then checked. + + No input returned. + """ + + def check(self, *code: Any, decode: bool = True) -> str | bytes | None: + """ + Check and fix ``ruff`` rule violations. + + All cases will join ``code`` to a single ``str``. + """ + base = "ruff", "check" + if isinstance(code[0], Path): + fp = code[0] + fp.write_text(self._join(code[1:]), encoding="utf-8") + subprocess.run((*base, fp, *self.check_args), check=True) + return None + r = subprocess.run( + (*base, *self.check_args, *self._stdin_args), + input=self._join(code).encode(), + check=True, + capture_output=True, + ) + return r.stdout.decode() if decode else r.stdout + + @overload + def format(self, *code: _Code) -> str: + """Format arbitrarily nested input as a single block.""" + + @overload + def format(self, _target_file: Path, /, *code: None) -> None: + """ + Format an existing file. + + Running on `win32` after writing lines will ensure ``LF`` is used, and not ``CRLF``: + + ruff format --diff --check _target_file + """ + + @overload + def format(self, _encoded_result: bytes, /, *code: None) -> str: + """Format the raw result of ``ruff.check``.""" + + def format(self, *code: Any) -> str | None: + """ + Format some input code, or an existing file. + + Returns decoded result unless formatting an existing file. + """ + base = "ruff", "format" + if len(code) == 1 and isinstance(code[0], Path): + subprocess.run((*base, code[0], *self.format_args), check=True) + return None + encoded = ( + code[0] + if len(code) == 1 and isinstance(code[0], bytes) + else self._join(code).encode() + ) + r = subprocess.run( + (*base, *self.format_args, *self._stdin_args), + input=encoded, + check=True, + capture_output=True, + ) + return r.stdout.decode() + + def __call__(self, *code: _Code) -> str: + return self.format(self.check(code, decode=False)) + + +ruff_inline_docs = Ruff(ignore="E711", skip_magic_traling_comma=True) +ruff = Ruff() diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 3b120bc77..0db8772e8 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -20,6 +20,7 @@ sys.path.insert(0, str(Path.cwd())) +from tools.codemod import ruff from tools.markup import rst_syntax_for_class from tools.schemapi import CodeSnippet, SchemaInfo, arg_kwds, arg_required_kwds, codegen from tools.schemapi.utils import ( @@ -31,8 +32,6 @@ import_typing_extensions, indent_docstring, resolve_references, - ruff_format_py, - ruff_write_lint_format_str, spell_literal, ) from tools.vega_expr import write_expr_module @@ -548,7 +547,7 @@ def copy_schemapi_util() -> None: dest.write(HEADER_COMMENT) dest.writelines(source.readlines()) if sys.platform == "win32": - ruff_format_py(destination_fp) + ruff.format(destination_fp) def recursive_dict_update(schema: dict, root: dict, def_dict: dict) -> None: @@ -1024,7 +1023,7 @@ def vegalite_main(skip_download: bool = False) -> None: f"SCHEMA_VERSION = '{version}'\n", f"SCHEMA_URL = {schema_url(version)!r}\n", ] - ruff_write_lint_format_str(outfile, content) + ruff.write_lint_format(outfile, content) TypeAliasTracer.update_aliases(("Map", "Mapping[str, Any]")) @@ -1106,7 +1105,7 @@ def vegalite_main(skip_download: bool = False) -> None: # Write the pre-generated modules for fp, contents in files.items(): print(f"Writing\n {schemafile!s}\n ->{fp!s}") - ruff_write_lint_format_str(fp, contents) + ruff.write_lint_format(fp, contents) def generate_encoding_artifacts( diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py index 6bc7b1f4b..351c355d0 100644 --- a/tools/schemapi/utils.py +++ b/tools/schemapi/utils.py @@ -4,7 +4,6 @@ import json import re -import subprocess import sys import textwrap import urllib.parse @@ -19,13 +18,13 @@ Iterator, Literal, Mapping, - MutableSequence, Sequence, TypeVar, Union, overload, ) +from tools.codemod import ruff from tools.markup import RSTParseVegaLite, rst_syntax_for_class from tools.schemapi.schemapi import _resolve_references as resolve_references @@ -91,13 +90,6 @@ class _TypeAliasTracer: A format specifier to produce the `TypeAlias` name. Will be provided a `SchemaInfo.title` as a single positional argument. - *ruff_check - Optional [ruff rule codes](https://docs.astral.sh/ruff/rules/), - each prefixed with `--select ` and follow a `ruff check --fix ` call. - - If not provided, uses `[tool.ruff.lint.select]` from `pyproject.toml`. - ruff_format - Optional argument list supplied to [ruff format](https://docs.astral.sh/ruff/formatter/#ruff-format) Attributes ---------- @@ -111,12 +103,7 @@ class _TypeAliasTracer: Prefined import statements to appear at beginning of module. """ - def __init__( - self, - fmt: str = "{}_T", - *ruff_check: str, - ruff_format: Sequence[str] | None = None, - ) -> None: + def __init__(self, fmt: str = "{}_T") -> None: self.fmt: str = fmt self._literals: dict[str, str] = {} self._literals_invert: dict[str, str] = {} @@ -135,10 +122,6 @@ def __init__( import_typing_extensions((3, 10), "TypeAlias"), import_typing_extensions((3, 9), "Annotated", "get_args"), ) - self._cmd_check: list[str] = ["--fix"] - self._cmd_format: Sequence[str] = ruff_format or () - for c in ruff_check: - self._cmd_check.extend(("--extend-select", c)) def _update_literals(self, name: str, tp: str, /) -> None: """Produces an inverted index, to reuse a `Literal` when `SchemaInfo.title` is empty.""" @@ -223,13 +206,6 @@ def write_module( extra `tools.generate_schema_wrapper.TYPING_EXTRA`. """ - ruff_format: MutableSequence[str | Path] = ["ruff", "format", fp] - if self._cmd_format: - ruff_format.extend(self._cmd_format) - commands: tuple[Sequence[str | Path], ...] = ( - ["ruff", "check", fp, *self._cmd_check], - ruff_format, - ) static = (header, "\n", *self._imports, "\n\n") self.update_aliases(*sorted(self._literals.items(), key=itemgetter(0))) all_ = [*iter(self._aliases), *extra_all] @@ -238,10 +214,7 @@ def write_module( [f"__all__ = {all_}", "\n\n", extra], self.generate_aliases(), ) - fp.write_text("\n".join(it), encoding="utf-8") - for cmd in commands: - r = subprocess.run(cmd, check=True) - r.check_returncode() + ruff.write_lint_format(fp, it) @property def n_entries(self) -> int: @@ -997,49 +970,6 @@ def unwrap_literal(tp: str, /) -> str: return re.sub(r"Literal\[(.+)\]", r"\g<1>", tp) -def ruff_format_py(fp: Path, /, *extra_args: str) -> None: - """ - Format an existing file. - - Running on `win32` after writing lines will ensure "lf" is used before: - ```bash - ruff format --diff --check . - ``` - """ - cmd: MutableSequence[str | Path] = ["ruff", "format", fp] - if extra_args: - cmd.extend(extra_args) - r = subprocess.run(cmd, check=True) - r.check_returncode() - - -def ruff_write_lint_format_str( - fp: Path, code: str | Iterable[str], /, *, encoding: str = "utf-8" -) -> None: - """ - Combined steps of writing, `ruff check`, `ruff format`. - - Notes - ----- - - `fp` is written to first, as the size before formatting will be the smallest - - Better utilizes `ruff` performance, rather than `python` str and io - - `code` is no longer bound to `list` - - Encoding set as default - - `I001/2` are `isort` rules, to sort imports. - """ - commands: Iterable[Sequence[str | Path]] = ( - ["ruff", "check", fp, "--fix"], - ["ruff", "check", fp, "--fix", "--select", "I001", "--select", "I002"], - ) - if not isinstance(code, str): - code = "\n".join(code) - fp.write_text(code, encoding=encoding) - for cmd in commands: - r = subprocess.run(cmd, check=True) - r.check_returncode() - ruff_format_py(fp) - - def import_type_checking(*imports: str) -> str: """Write an `if TYPE_CHECKING` block.""" imps = "\n".join(f" {s}" for s in imports) @@ -1066,7 +996,7 @@ def import_typing_extensions( ) -TypeAliasTracer: _TypeAliasTracer = _TypeAliasTracer("{}_T", "I001", "I002") +TypeAliasTracer: _TypeAliasTracer = _TypeAliasTracer("{}_T") """An instance of `_TypeAliasTracer`. Collects a cache of unique `Literal` types used globally. diff --git a/tools/update_init_file.py b/tools/update_init_file.py index c1831093a..0592032f0 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from tools.schemapi.utils import ruff_write_lint_format_str +from tools.codemod import ruff _TYPING_CONSTRUCTS = { te.TypeAlias, @@ -74,7 +74,7 @@ def update__all__variable() -> None: ] # Write new version of altair/__init__.py # Format file content with ruff - ruff_write_lint_format_str(init_path, new_lines) + ruff.write_lint_format(init_path, new_lines) def relevant_attributes(namespace: dict[str, t.Any], /) -> list[str]: diff --git a/tools/vega_expr.py b/tools/vega_expr.py index ce87cb2fb..66d6287fb 100644 --- a/tools/vega_expr.py +++ b/tools/vega_expr.py @@ -29,12 +29,10 @@ overload, ) +from tools.codemod import ruff from tools.markup import RSTParse, Token, read_ast_tokens from tools.markup import RSTRenderer as _RSTRenderer from tools.schemapi.schemapi import SchemaBase as _SchemaBase -from tools.schemapi.utils import ( - ruff_write_lint_format_str as _ruff_write_lint_format_str, -) if TYPE_CHECKING: import sys @@ -977,4 +975,4 @@ def write_expr_module(version: str, output: Path, *, header: str) -> None: [MODULE_POST], ) print(f"Generating\n {url!s}\n ->{output!s}") - _ruff_write_lint_format_str(output, contents) + ruff.write_lint_format(output, contents)