diff --git a/doc/conf.py b/doc/conf.py
index 3e210a3fa..a01b6b85d 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -39,6 +39,7 @@
"sphinxext_altair.altairplot",
"sphinxext.altairgallery",
"sphinxext.schematable",
+ "sphinxext.code_ref",
"sphinx_copybutton",
"sphinx_design",
]
diff --git a/doc/user_guide/customization.rst b/doc/user_guide/customization.rst
index c068d62e3..ec8df035e 100644
--- a/doc/user_guide/customization.rst
+++ b/doc/user_guide/customization.rst
@@ -787,10 +787,16 @@ If you would like to use any theme just for a single chart, you can use the
with alt.themes.enable('default'):
spec = chart.to_json()
+Built-in Themes
+~~~~~~~~~~~~~~~
Currently Altair does not offer many built-in themes, but we plan to add
more options in the future.
-See `Vega Theme Test`_ for an interactive demo of themes inherited from `Vega Themes`_.
+You can get a feel for the themes inherited from `Vega Themes`_ via *Vega-Altair Theme Test* below:
+
+.. altair-theme:: tests.altair_theme_test.alt_theme_test
+ :fold:
+ :summary: Show Vega-Altair Theme Test
Defining a Custom Theme
~~~~~~~~~~~~~~~~~~~~~~~
@@ -843,6 +849,13 @@ If you want to restore the default theme, use:
alt.themes.enable('default')
+When experimenting with your theme, you can use the code below to see how
+it translates across a range of charts/marks:
+
+.. altair-code-ref:: tests.altair_theme_test.alt_theme_test
+ :fold:
+ :summary: Show Vega-Altair Theme Test code
+
For more ideas on themes, see the `Vega Themes`_ repository.
@@ -889,5 +902,4 @@ The configured localization settings persist upon saving.
alt.renderers.set_embed_options(format_locale="en-US", time_format_locale="en-US")
.. _Vega Themes: https://github.com/vega/vega-themes/
-.. _`D3's localization support`: https://d3-wiki.readthedocs.io/zh-cn/master/Localization/
-.. _Vega Theme Test: https://vega.github.io/vega-themes/?renderer=canvas
\ No newline at end of file
+.. _`D3's localization support`: https://d3-wiki.readthedocs.io/zh-cn/master/Localization/
\ No newline at end of file
diff --git a/sphinxext/code_ref.py b/sphinxext/code_ref.py
new file mode 100644
index 000000000..6371713de
--- /dev/null
+++ b/sphinxext/code_ref.py
@@ -0,0 +1,330 @@
+"""Sphinx extension providing formatted code blocks, referencing some function."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal, get_args
+
+from docutils import nodes
+from docutils.parsers.rst import directives
+from sphinx.util.docutils import SphinxDirective
+from sphinx.util.parsing import nested_parse_to_nodes
+
+from altair.vegalite.v5.schema._typing import VegaThemes
+from tools.codemod import extract_func_def, extract_func_def_embed
+
+if TYPE_CHECKING:
+ import sys
+ from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ Iterable,
+ Iterator,
+ Mapping,
+ Sequence,
+ TypeVar,
+ Union,
+ )
+
+ from docutils.parsers.rst.states import RSTState, RSTStateMachine
+ from docutils.statemachine import StringList
+ from sphinx.application import Sphinx
+
+ if sys.version_info >= (3, 12):
+ from typing import TypeAliasType
+ else:
+ from typing_extensions import TypeAliasType
+ if sys.version_info >= (3, 10):
+ from typing import TypeAlias
+ else:
+ from typing_extensions import TypeAlias
+
+ T = TypeVar("T")
+ OneOrIter = TypeAliasType("OneOrIter", Union[T, Iterable[T]], type_params=(T,))
+
# Short names accepted by the ``:output:`` option of ``altair-code-ref``.
_OutputShort: TypeAlias = Literal["code", "plot"]
# Names of the rst directives that the short names expand to.
_OutputLong: TypeAlias = Literal["code-block", "altair-plot"]
# Lookup used by ``validate_output`` to translate short -> long.
_OUTPUT_REMAP: Mapping[_OutputShort, _OutputLong] = {
    "code": "code-block",
    "plot": "altair-plot",
}
# Option names understood by ``CodeRefDirective``.
_Option: TypeAlias = Literal["output", "fold", "summary"]

# The PyScript runtime is pinned to a single release; bump ``_PYSCRIPT_VERSION``
# to upgrade, the URL is derived from it.
_PYSCRIPT_URL_FMT = "https://pyscript.net/releases/{0}/core.js"
_PYSCRIPT_VERSION = "2024.10.1"
_PYSCRIPT_URL = _PYSCRIPT_URL_FMT.format(_PYSCRIPT_VERSION)
+
+
def validate_output(output: Any) -> _OutputLong:
    """
    Convert the ``:output:`` option into the name of the directive to emit.

    Parameters
    ----------
    output
        Raw option value. Leading/trailing whitespace and letter case are ignored.

    Returns
    -------
    The long directive name looked up in ``_OUTPUT_REMAP``.

    Raises
    ------
    TypeError
        If the option has no value, is not a string, or is not a supported short name.
    """
    msg = f":output: option must be one of {get_args(_OutputShort)!r}"
    # An option written without an argument arrives as ``None``; report it with the
    # same error as any other invalid value instead of failing on ``None.strip()``.
    if not isinstance(output, str):
        raise TypeError(msg)
    short = output.strip().lower()
    # The mapping is the single source of truth for which short names are valid.
    if short not in _OUTPUT_REMAP:
        raise TypeError(msg)
    return _OUTPUT_REMAP[short]
+
+
def validate_packages(packages: Any) -> str:
    """
    Convert the ``:packages:`` option into a list literal of quoted package names.

    Parameters
    ----------
    packages
        Comma-separated package names, or ``None`` when the option is not provided.

    Returns
    -------
    A string such as ``'["altair","pandas"]'``.
    Falls back to ``'["altair"]'`` when no package name is given.
    """
    default = '["altair"]'
    if packages is None:
        return default
    # Whitespace around names is dropped, as are empty entries (e.g. a trailing comma).
    names = [pkg.strip() for pkg in packages.split(",")]
    # Every name is quoted, regardless of how many there are.
    quoted = ",".join(f'"{nm}"' for nm in names if nm)
    return f"[{quoted}]" if quoted else default
+
+
def raw_html(text: str, /) -> nodes.raw:
    """Build a docutils ``raw`` node carrying ``text``, tagged with ``format="html"``."""
    node = nodes.raw("", text, format="html")
    return node
+
+
+def maybe_details(
+ parsed: Iterable[nodes.Node], options: dict[_Option, Any], *, default_summary: str
+) -> Sequence[nodes.Node]:
+ """
+ Wrap ``parsed`` in a folding `details`_ block if requested.
+
+ Parameters
+ ----------
+ parsed
+ Target nodes that have been processed.
+ options
+ Optional arguments provided to ``.. altair-code-ref::``.
+
+ .. note::
+ If no relevant options are specified,
+ ``parsed`` is returned unchanged.
+
+ default_summary
+ Label text used when **only** specifying ``:fold:``.
+
+ .. _details:
+ https://developer.mozilla.org/en-US/docs/Web/HTML/Element/details
+ """
+
+ def gen() -> Iterator[nodes.Node]:
+ if {"fold", "summary"}.isdisjoint(options.keys()):
+ yield from parsed
+ else:
+ summary = options.get("summary", default_summary)
+ yield raw_html(f"
{summary}
")
+ yield from parsed
+ yield raw_html("
")
+
+ return list(gen())
+
+
def theme_names() -> tuple[Sequence[str], Sequence[str]]:
    """
    Split the names in ``VegaThemes`` into two sorted groups.

    Returns
    -------
    ``(standard, carbon)``, where ``standard`` starts with ``"default"`` and
    ``carbon`` holds every name that starts with ``"carbon"``.
    """
    standard: list[str] = []
    carbon: list[str] = []
    # Sorting once up front keeps both partitions in sorted order.
    for name in sorted(set(get_args(VegaThemes))):
        bucket = carbon if name.startswith("carbon") else standard
        bucket.append(name)
    return ["default", *standard], carbon
+
+
+def option(label: str, value: str | None = None, /) -> nodes.raw:
+ s = f"\n")
+
+
+def optgroup(label: str, *options: OneOrIter[nodes.raw]) -> Iterator[nodes.raw]:
+ yield raw_html(f"\n")
+
+
+def dropdown(
+ id: str, label: str | None, extra_select: str, *options: OneOrIter[nodes.raw]
+) -> Iterator[nodes.raw]:
+ if label:
+ yield raw_html(f"\n")
+ select_text = f"\n")
+
+
+def pyscript(
+ packages: str, target_div_id: str, loading_label: str, py_code: str
+) -> Iterator[nodes.raw]:
+ PY = "py"
+ LB, RB = "{", "}"
+ packages = f""""packages":{packages}"""
+ yield raw_html(f"{loading_label}
\n")
+ yield raw_html(f"\n")
+
+
def _before_code(refresh_name: str, select_id: str, target_div_id: str) -> str:
    """
    Return the Python prelude placed before the extracted chart code.

    The prelude imports the browser/PyScript helpers and defines a function named
    ``refresh_name`` which reads the element with id ``select_id``, applies its
    value as the embed theme, and displays ``chart`` into ``target_div_id``.
    """
    pad = " " * 4
    prelude = (
        "from js import document",
        "from pyscript import display",
        "import altair as alt",
        "",
    )
    handler = (
        f"def {refresh_name}(*args):",
        f"{pad}selected = document.getElementById({select_id!r}).value",
        f"{pad}alt.renderers.set_embed_options(theme=selected)",
        f"{pad}display(chart, append=False, target={target_div_id!r})",
    )
    # Every line, including the last one, is newline-terminated.
    return "".join(f"{line}\n" for line in (*prelude, *handler))
+
+
+class ThemeDirective(SphinxDirective):
+ """
+ Theme preview directive.
+
+ Similar to ``CodeRefDirective``, but uses `PyScript`_ to access the browser.
+
+ .. _PyScript:
+ https://pyscript.net/
+ """
+
+ has_content: ClassVar[Literal[False]] = False
+ required_arguments: ClassVar[Literal[1]] = 1
+ option_spec = {
+ "packages": validate_packages,
+ "dropdown-label": directives.unchanged,
+ "loading-label": directives.unchanged,
+ "fold": directives.flag,
+ "summary": directives.unchanged_required,
+ }
+
+ def run(self) -> Sequence[nodes.Node]:
+ results: list[nodes.Node] = []
+ SELECT_ID = "embed_theme"
+ REFRESH_NAME = "apply_embed_input"
+ TARGET_DIV_ID = "render_altair"
+ standard_names, carbon_names = theme_names()
+
+ qual_name = self.arguments[0]
+ module_name, func_name = qual_name.rsplit(".", 1)
+ dropdown_label = self.options.get("dropdown-label", "Select theme:")
+ loading_label = self.options.get("loading-label", "loading...")
+ packages: str = self.options.get("packages", validate_packages(None))
+
+ results.append(raw_html("\n"))
+ results.extend(
+ dropdown(
+ SELECT_ID,
+ dropdown_label,
+ f"py-input={REFRESH_NAME!r}",
+ (option(nm) for nm in standard_names),
+ optgroup("Carbon", (option(nm) for nm in carbon_names)),
+ )
+ )
+ py_code = extract_func_def_embed(
+ module_name,
+ func_name,
+ before=_before_code(REFRESH_NAME, SELECT_ID, TARGET_DIV_ID),
+ after=f"{REFRESH_NAME}()",
+ assign_to="chart",
+ indent=4,
+ )
+ results.extend(
+ pyscript(packages, TARGET_DIV_ID, loading_label, py_code=py_code)
+ )
+ results.append(raw_html("
\n"))
+ return maybe_details(
+ results, self.options, default_summary="Show Vega-Altair Theme Test"
+ )
+
+
class PyScriptDirective(SphinxDirective):
    """
    Placeholder for non-theme related directive.

    Accepts no content and a single ``packages`` option.
    ``run`` is not implemented yet and always raises.
    """

    has_content: ClassVar[Literal[False]] = False
    # ``packages`` is passed through untouched for now.
    option_spec = {"packages": directives.unchanged}

    def run(self) -> Sequence[nodes.Node]:
        raise NotImplementedError
+
+
class CodeRefDirective(SphinxDirective):
    """
    Formatted code block, referencing the contents of a function definition.

    Options:

        .. altair-code-ref::
            :output: [code, plot]
            :fold: flag
            :summary: str

    Examples
    --------
    Reference a function, generating a code block:

        .. altair-code-ref:: package.module.function

    Wrap the code block in a collapsible `details`_ tag:

        .. altair-code-ref:: package.module.function
            :fold:

    Override default ``"Show code"`` `details`_ summary:

        .. altair-code-ref:: package.module.function
            :fold:
            :summary: Look here!

    Use `altair-plot`_ instead of a code block:

        .. altair-code-ref:: package.module.function
            :output: plot

    .. note::
        Using `altair-plot`_ currently ignores the other options.

    .. _details:
        https://developer.mozilla.org/en-US/docs/Web/HTML/Element/details
    .. _altair-plot:
        https://github.com/vega/sphinxext-altair
    """

    has_content: ClassVar[Literal[False]] = False
    required_arguments: ClassVar[Literal[1]] = 1
    option_spec: ClassVar[dict[_Option, Callable[[str], Any]]] = {
        "output": validate_output,
        "fold": directives.flag,
        "summary": directives.unchanged_required,
    }

    def __init__(
        self,
        name: str,
        arguments: list[str],
        options: dict[_Option, Any],
        content: StringList,
        lineno: int,
        content_offset: int,
        block_text: str,
        state: RSTState,
        state_machine: RSTStateMachine,
    ) -> None:
        # Overridden only to narrow the declared type of ``self.options``.
        super().__init__(name, arguments, options, content, lineno, content_offset, block_text, state, state_machine)  # fmt: skip
        self.options: dict[_Option, Any]

    def run(self) -> Sequence[nodes.Node]:
        """
        Extract the referenced function body and parse it as nested rst.

        Raises
        ------
        docutils.parsers.rst.DirectiveError
            If the argument is not of the form ``module.function``.
        """
        qual_name = self.arguments[0]
        module_name, dot, func_name = qual_name.rpartition(".")
        # Without a dot (or with an empty side) there is no module/function pair;
        # report it as a directive error rather than an unpacking failure.
        if not (dot and module_name and func_name):
            msg = (
                f"Expected a dotted reference like 'package.module.function', "
                f"got {qual_name!r}"
            )
            raise self.error(msg)
        output: _OutputLong = self.options.get("output", "code-block")
        content = extract_func_def(module_name, func_name, output=output)
        parsed = nested_parse_to_nodes(self.state, content)
        return maybe_details(parsed, self.options, default_summary="Show code")
+
+
def setup(app: Sphinx) -> None:
    """Register this extension's directives and the PyScript runtime with ``app``."""
    # Directives first, then the script they rely on in the browser.
    app.add_directive_to_domain("py", "altair-code-ref", CodeRefDirective)
    app.add_directive("altair-theme", ThemeDirective)
    # TODO: register ``PyScriptDirective`` as "altair-pyscript" once it is implemented.
    app.add_js_file(_PYSCRIPT_URL, loading_method="defer", type="module")
diff --git a/tests/altair_theme_test.py b/tests/altair_theme_test.py
new file mode 100644
index 000000000..b9114baec
--- /dev/null
+++ b/tests/altair_theme_test.py
@@ -0,0 +1,141 @@
+# ruff: noqa: E711
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from altair.typing import ChartType
+
+
def alt_theme_test() -> ChartType:
    # NOTE: The body of this function is extracted statement-by-statement by
    # ``tools.codemod`` for the docs. ``validate_body`` there requires the first
    # statement to be an import and the last one a ``return`` -- so keep the
    # import inline and use ``#`` comments only (no docstring).
    import altair as alt

    # Remote datasets, pinned to a fixed ``vega-datasets`` release.
    VEGA_DATASETS = "https://cdn.jsdelivr.net/npm/vega-datasets@v1.29.0/data/"
    us_10m = f"{VEGA_DATASETS}us-10m.json"
    unemployment = f"{VEGA_DATASETS}unemployment.tsv"
    movies = f"{VEGA_DATASETS}movies.json"
    barley = f"{VEGA_DATASETS}barley.json"
    iowa_electricity = f"{VEGA_DATASETS}iowa-electricity.csv"
    # Small inline dataset shared by the three charts in the first row.
    common_data = alt.InlineData(
        [
            {"Index": 1, "Value": 28, "Position": 1, "Category": "A"},
            {"Index": 2, "Value": 55, "Position": 2, "Category": "A"},
            {"Index": 3, "Value": 43, "Position": 3, "Category": "A"},
            {"Index": 4, "Value": 91, "Position": 4, "Category": "A"},
            {"Index": 5, "Value": 81, "Position": 5, "Category": "A"},
            {"Index": 6, "Value": 53, "Position": 6, "Category": "A"},
            {"Index": 7, "Value": 19, "Position": 1, "Category": "B"},
            {"Index": 8, "Value": 87, "Position": 2, "Category": "B"},
            {"Index": 9, "Value": 52, "Position": 3, "Category": "B"},
            {"Index": 10, "Value": 48, "Position": 4, "Category": "B"},
            {"Index": 11, "Value": 24, "Position": 5, "Category": "B"},
            {"Index": 12, "Value": 49, "Position": 6, "Category": "B"},
            {"Index": 13, "Value": 87, "Position": 1, "Category": "C"},
            {"Index": 14, "Value": 66, "Position": 2, "Category": "C"},
            {"Index": 15, "Value": 17, "Position": 3, "Category": "C"},
            {"Index": 16, "Value": 27, "Position": 4, "Category": "C"},
            {"Index": 17, "Value": 68, "Position": 5, "Category": "C"},
            {"Index": 18, "Value": 16, "Position": 6, "Category": "C"},
        ]
    )

    # Shared chart dimensions (pixels).
    HEIGHT_SMALL = 140
    STANDARD = 180
    WIDTH_GEO = int(STANDARD * 1.667)

    # --- Row 1: charts built from ``common_data`` ---
    bar = (
        alt.Chart(common_data, height=HEIGHT_SMALL, width=STANDARD, title="Bar")
        .mark_bar()
        .encode(
            x=alt.X("Index:O").axis(offset=1), y=alt.Y("Value:Q"), tooltip="Value:Q"
        )
        .transform_filter(alt.datum["Index"] <= 9)
    )
    line = (
        alt.Chart(common_data, height=HEIGHT_SMALL, width=STANDARD, title="Line")
        .mark_line()
        .encode(
            x=alt.X("Position:O").axis(grid=False),
            y=alt.Y("Value:Q").axis(grid=False),
            color=alt.Color("Category:N").legend(None),
            tooltip=["Index:O", "Value:Q", "Position:O", "Category:N"],
        )
    )
    point_shape = (
        alt.Chart(
            common_data, height=HEIGHT_SMALL, width=STANDARD, title="Point (Shape)"
        )
        .mark_point()
        .encode(
            x=alt.X("Position:O").axis(grid=False),
            y=alt.Y("Value:Q").axis(grid=False),
            shape=alt.Shape("Category:N").legend(None),
            color=alt.Color("Category:N").legend(None),
            tooltip=["Index:O", "Value:Q", "Position:O", "Category:N"],
        )
    )
    # --- Row 2: charts built from remote datasets ---
    point = (
        alt.Chart(movies, height=STANDARD, width=STANDARD, title="Point")
        .mark_point(tooltip=True)
        # ``!= None`` is intentional here; E711 is suppressed for this file.
        .transform_filter(alt.datum["IMDB_Rating"] != None)
        .transform_filter(
            alt.FieldRangePredicate("Release_Date", [None, 2019], timeUnit="year")
        )
        .transform_joinaggregate(Average_Rating="mean(IMDB_Rating)")
        .transform_calculate(
            Rating_Delta=alt.datum["IMDB_Rating"] - alt.datum.Average_Rating
        )
        .encode(
            x=alt.X("Release_Date:T").title("Release Date"),
            y=alt.Y("Rating_Delta:Q").title("Rating Delta"),
            color=alt.Color("Rating_Delta:Q").title("Rating Delta").scale(domainMid=0),
        )
    )
    bar_stack = (
        alt.Chart(barley, height=STANDARD, width=STANDARD, title="Bar (Stacked)")
        .mark_bar(tooltip=True)
        .encode(
            x="sum(yield):Q",
            y=alt.Y("variety:N"),
            color=alt.Color("site:N").legend(orient="bottom", columns=2),
        )
    )
    # --- Row 3 ---
    area = (
        alt.Chart(iowa_electricity, height=STANDARD, width=STANDARD, title="Area")
        .mark_area(tooltip=True)
        .encode(
            x=alt.X("year:T").title("Year"),
            y=alt.Y("net_generation:Q")
            .title("Share of net generation")
            .stack("normalize")
            .axis(format=".0%"),
            color=alt.Color("source:N")
            .title("Electricity source")
            .legend(orient="bottom", columns=2),
        )
    )
    geoshape = (
        alt.Chart(
            alt.topo_feature(us_10m, "counties"),
            height=STANDARD,
            width=WIDTH_GEO,
            title=alt.Title("Geoshape", subtitle="Unemployment rate per county"),
        )
        .mark_geoshape(tooltip=True)
        .encode(color="rate:Q")
        .transform_lookup(
            "id", alt.LookupData(alt.UrlData(unemployment), "id", ["rate"])
        )
        .project(type="albersUsa")
    )

    # ``|`` concatenates horizontally, ``&`` stacks the rows vertically.
    compound_chart = (
        (bar | line | point_shape) & (point | bar_stack) & (area | geoshape)
    ).properties(
        title=alt.Title(
            "Vega-Altair Theme Test",
            fontSize=20,
            subtitle="Adapted from https://vega.github.io/vega-themes/",
        )
    )
    return compound_chart
diff --git a/tools/codemod.py b/tools/codemod.py
new file mode 100644
index 000000000..8281e0732
--- /dev/null
+++ b/tools/codemod.py
@@ -0,0 +1,453 @@
+# ruff: noqa: D418
+from __future__ import annotations
+
+import ast
+import subprocess
+import sys
+import textwrap
+import warnings
+from collections import deque
+from importlib.util import find_spec
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Iterable, TypeVar, overload
+
+if sys.version_info >= (3, 12):
+ from typing import Protocol, TypeAliasType
+else:
+ from typing_extensions import Protocol, TypeAliasType
+
+if TYPE_CHECKING:
+ if sys.version_info >= (3, 11):
+ from typing import LiteralString
+ else:
+ from typing_extensions import LiteralString
+
+ from typing import ClassVar, Iterator, Literal
+
+
+__all__ = ["extract_func_def", "extract_func_def_embed", "ruff", "ruff_inline_docs"]
+
# Element type of the recursive alias below.
T = TypeVar("T")
# A single ``T``, an iterable of ``T``, or iterables nested to any depth.
# The value is given as a string because the alias refers to itself.
OneOrIterV = TypeAliasType(
    "OneOrIterV",
    "T | Iterable[T] | Iterable[OneOrIterV[T]]",
    type_params=(T,),
)
# Code fragments accepted by the helpers in this module; see ``iter_flatten``.
_Code = OneOrIterV[str]
+
+
def parse_module(name: str, /) -> ast.Module:
    """
    Locate a module by its import name and parse its source into an ast.

    Use regular dotted import style, no `.py` suffix.

    Acceptable ``name``:

        altair.package.subpackage.etc
        tools.____
        tests.____
        doc.____
        sphinxext.____

    Raises
    ------
    FileNotFoundError
        If no import spec (or no origin for it) can be found for ``name``.
    """
    spec = find_spec(name)
    # Guard clause: bail out when the module, or its source location, is unknown.
    if not spec or not spec.origin:
        raise FileNotFoundError(name)
    source = Path(spec.origin).read_bytes()
    return ast.parse(source)
+
+
if sys.version_info < (3, 9):

    def unparse(obj: ast.AST, /) -> str:
        """
        Fallback for interpreters that lack ``ast.unparse`` (added in ``3.9``).

        Emits an ``ImportWarning`` and returns an empty string.

        https://docs.python.org/3/library/ast.html#ast.unparse
        """
        # HACK: Will only be used during build/docs
        # - This branch is just to satisfy linters
        msg = f"Called `ast.unparse()` on {sys.version_info!r}\nFunction not available before {(3, 9)!r}"
        warnings.warn(msg, ImportWarning, stacklevel=2)
        return ""
else:

    def unparse(obj: ast.AST, /) -> str:
        """Return the source code for ``obj``, as produced by ``ast.unparse``."""
        return ast.unparse(obj)
+
+
def find_func_def(mod: ast.Module, fn_name: str, /) -> ast.FunctionDef:
    """
    Return the first top-level function in ``mod`` named ``fn_name``.

    Notes
    -----
    Provides some extra type safety, over::

        ast.Module.body: list[ast.stmt]

    Raises
    ------
    NotImplementedError
        If ``mod`` has no top-level ``def`` with that name.
    """
    candidates = (
        node
        for node in mod.body
        if isinstance(node, ast.FunctionDef) and node.name == fn_name
    )
    found = next(candidates, None)
    if found is None:
        msg = f"Found no function named {fn_name!r}"
        raise NotImplementedError(msg)
    return found
+
+
def validate_body(fn: ast.FunctionDef, /) -> tuple[list[ast.stmt], ast.expr]:
    """
    Ensure function has inlined imports and a return statement.

    ``fn`` itself is left unmodified, so the same node can be validated repeatedly.

    Returns::

        (ast.FunctionDef.body[:-1], ast.Return.value)

    Raises
    ------
    TypeError
        If the first statement is not an import, or the last statement is not
        a ``return`` with a value.
    """
    # Unpack into a new list rather than popping from ``fn.body``.
    *body, last = fn.body
    first = fn.body[0]
    if not isinstance(first, (ast.Import, ast.ImportFrom)):
        msg = (
            f"First statement in function must be an import, "
            f"got {type(first).__name__!r}\n\n"
            f"{unparse(first)!r}"
        )
        raise TypeError(msg)
    if not isinstance(last, ast.Return) or last.value is None:
        msg = (
            f"Last statement in function must return an expression, "
            f"got {type(last).__name__!r}\n\n"
            f"{unparse(last)!r}"
        )
        raise TypeError(msg)
    return body, last.value
+
+
def iter_flatten(*elements: _Code) -> Iterator[str]:
    """
    Yield every ``str`` found in ``elements``, descending into nested iterables.

    Raises
    ------
    TypeError
        On reaching an element that is neither a ``str`` nor an iterable.
    """
    for item in elements:
        # Strings are iterable too, so they must be matched before the recursion.
        if isinstance(item, str):
            yield item
            continue
        if isinstance(item, Iterable):
            yield from iter_flatten(*item)
            continue
        msg = (
            f"Expected all elements to eventually flatten to ``str``, "
            f"but got: {type(item).__name__!r}\n\n"
            f"{item!r}"
        )
        raise TypeError(msg)
+
+
def iter_func_def_unparse(
    module_name: str,
    func_name: str,
    /,
    *,
    return_transform: Literal["assign"] | None = None,
    assign_to: str = "chart",
) -> Iterator[str]:
    """
    Yield the source of each statement in a function body, then the returned expression.

    Parameters
    ----------
    module_name
        Absolute, dotted import style.
    func_name
        Name of function in ``module_name``.
    return_transform
        ``None`` yields the returned expression as-is;
        ``"assign"`` yields ``"{assign_to} = <expression>"`` instead.
    assign_to
        Variable name used by the ``"assign"`` transform, so code added before/after
        does not depend on the names used inside the original function.
    """
    fn = find_func_def(parse_module(module_name), func_name)
    statements, returned = validate_body(fn)
    yield from map(unparse, statements)
    # Empty fragment, placed between the body and the final expression.
    yield ""
    expr = unparse(returned)
    if return_transform == "assign":
        yield f"{assign_to} = {expr}"
    elif return_transform is None:
        yield expr
    else:
        msg = f"{return_transform=}"
        raise NotImplementedError(msg)
+
+
def extract_func_def(
    module_name: str,
    func_name: str,
    *,
    format: bool = True,
    output: Literal["altair-plot", "code-block", "str"] = "str",
) -> str:
    """
    Return the body of a function as source text, optionally inside an `rst` directive.

    Parameters
    ----------
    module_name
        Absolute, dotted import style.
    func_name
        Name of function in ``module_name``.
    format
        Run through ``ruff format`` before returning.
    output
        ``"str"`` returns the bare code; any other accepted value is used as the
        name of the directive the (indented) code is embedded in.

    Raises
    ------
    TypeError
        If ``output`` is not one of the accepted values.

    Notes
    -----
    - Functions must define all imports inline, to ensure they are propagated
    - Must end with a single return statement

    Warning
    -------
    Requires ``python>=3.9`` for `ast.unparse`_

    Examples
    --------
    Transform the contents of a function into a code block::

        >>> extract_func_def("tests.altair_theme_test", "alt_theme_test", output="code-block")  # doctest: +SKIP

    .. _ast.unparse:
        https://docs.python.org/3.9/library/ast.html#ast.unparse
    """
    if output not in {"altair-plot", "code-block", "str"}:
        raise TypeError(output)
    fragments = iter_func_def_unparse(module_name, func_name)
    if format:
        code = ruff_inline_docs.format(fragments)
    else:
        code = "\n".join(fragments)
    if output == "str":
        return code
    # Directive content must be indented beneath the directive line.
    indented = textwrap.indent(code, " " * 4)
    return f".. {output}::\n\n{indented}\n"
+
+
def extract_func_def_embed(
    module_name: str,
    func_name: str,
    /,
    before: _Code | None = None,
    after: _Code | None = None,
    assign_to: str = "chart",
    indent: int | None = None,
) -> str:
    """
    Return the body of a function, sandwiched between ``before`` and ``after``.

    The combined code is passed through ``ruff`` (check, then format), which
    deduplicates imports and applies consistent formatting.

    Parameters
    ----------
    module_name
        Absolute, dotted import style.
    func_name
        Name of function in ``module_name``.
    before
        Code inserted before ``func_name``.
    after
        Code inserted after ``func_name``.
    assign_to
        Variable name to use as the result of ``func_name``.

        .. note::
            Allows the ``after`` block to use a consistent reference.
    indent
        Optionally, prefix ``indent * " "`` to final block.

        .. note::
            Occurs **after** formatting, will not contribute to line length wrap.

    Warns
    -----
    UserWarning
        If neither ``before`` nor ``after`` is provided.
    """
    if before is None and after is None:
        msg = (
            f"At least one additional code fragment should be provided, but:\n"
            f"{before=}, {after=}\n\n"
            f"Use {extract_func_def.__qualname__!r} instead."
        )
        warnings.warn(msg, UserWarning, stacklevel=2)
    extracted = iter_func_def_unparse(
        module_name, func_name, return_transform="assign", assign_to=assign_to
    )
    fragments = [frag for frag in (before, extracted, after) if frag is not None]
    result = ruff_inline_docs(fragments)
    if indent:
        return textwrap.indent(result, " " * indent)
    return result
+
+
class CodeMod(Protocol):
    """Callable that turns code fragments into a single block of modified code."""

    def __call__(self, *code: _Code) -> str:
        """
        Transform some input into a single block of modified code.

        Parameters
        ----------
        *code
            Arbitrarily nested code fragments.
        """
        ...

    def _join(self, code: _Code, *, sep: str = "\n") -> str:
        """
        Flatten ``code`` and join the resulting fragments with ``sep``.

        Nesting of any depth is unwrapped before joining.
        """
        flat = iter_flatten(code)
        return sep.join(flat)
+
+
class Ruff(CodeMod):
    """
    Run `ruff`_ commands against code fragments or files.

    By default, uses the same config as `pyproject.toml`_.

    Parameters
    ----------
    *extend_select
        `rule codes`_ to use **on top of** the default config.
    ignore
        `rule codes`_ to `ignore`_.
    skip_magic_traling_comma
        Enables `skip-magic-trailing-comma`_ during formatting.

        .. note::

            Only use on code that is changing indent-level
            (e.g. unwrapping function contents).

    .. _ruff:
        https://docs.astral.sh/ruff/
    .. _pyproject.toml:
        https://github.com/vega/altair/blob/main/pyproject.toml
    .. _rule codes:
        https://docs.astral.sh/ruff/rules/
    .. _ignore:
        https://docs.astral.sh/ruff/settings/#lint_ignore
    .. _skip-magic-trailing-comma:
        https://docs.astral.sh/ruff/settings/#format_skip-magic-trailing-comma
    """

    # Appended whenever code is piped through stdin instead of read from a file.
    _stdin_args: ClassVar[tuple[LiteralString, ...]] = (
        "--stdin-filename",
        "placeholder.py",
    )
    # Always passed to ``ruff check``.
    _check_args: ClassVar[tuple[LiteralString, ...]] = ("--fix",)

    def __init__(
        self,
        *extend_select: str,
        ignore: OneOrIterV[str] | None = None,
        skip_magic_traling_comma: bool = False,
    ) -> None:
        # Per-instance argument lists, seeded from the class-level defaults.
        self.check_args: deque[str] = deque(self._check_args)
        self.format_args: deque[str] = deque()
        for c in extend_select:
            self.check_args.extend(("--extend-select", c))
        if ignore is not None:
            # ``ignore`` may be nested; flatten into one comma-separated value.
            self.check_args.extend(("--ignore", ",".join(iter_flatten(ignore))))
        if skip_magic_traling_comma:
            self.format_args.extend(
                ("--config", "format.skip-magic-trailing-comma = true")
            )

    def write_lint_format(self, fp: Path, code: _Code, /) -> None:
        """
        Combined steps of writing, `ruff check`, `ruff format`.

        Parameters
        ----------
        fp
            Target file to write to
        code
            Some (potentially) nested code fragments.

        Notes
        -----
        - `fp` is written to first, as the size before formatting will be the smallest
        - Better utilizes `ruff` performance, rather than `python` str and io
        """
        self.check(fp, code)
        self.format(fp)

    @overload
    def check(self, *code: _Code, decode: Literal[True] = ...) -> str:
        """Fixes violations and returns fixed code."""

    @overload
    def check(self, *code: _Code, decode: Literal[False]) -> bytes:
        """
        ``decode=False`` will return as ``bytes``.

        Helpful if piping to another command.
        """

    @overload
    def check(self, _write_to: Path, /, *code: _Code) -> None:
        """
        ``code`` is joined, written to provided path and then checked.

        No input returned.
        """

    def check(self, *code: Any, decode: bool = True) -> str | bytes | None:
        """
        Check and fix ``ruff`` rule violations.

        All cases will join ``code`` to a single ``str``.

        Raises
        ------
        subprocess.CalledProcessError
            If ``ruff`` exits with a non-zero status.
        """
        base = "ruff", "check"
        # ``code`` may be empty; only treat the first element as a target file
        # when there is one (mirrors how ``format`` accepts empty input).
        if code and isinstance(code[0], Path):
            fp = code[0]
            fp.write_text(self._join(code[1:]), encoding="utf-8")
            subprocess.run((*base, fp, *self.check_args), check=True)
            return None
        r = subprocess.run(
            (*base, *self.check_args, *self._stdin_args),
            input=self._join(code).encode(),
            check=True,
            capture_output=True,
        )
        return r.stdout.decode() if decode else r.stdout

    @overload
    def format(self, *code: _Code) -> str:
        """Format arbitrarily nested input as a single block."""

    @overload
    def format(self, _target_file: Path, /, *code: None) -> None:
        """
        Format an existing file.

        Running on `win32` after writing lines will ensure ``LF`` is used, and not ``CRLF``:

            ruff format --diff --check _target_file
        """

    @overload
    def format(self, _encoded_result: bytes, /, *code: None) -> str:
        """Format the raw result of ``ruff.check``."""

    def format(self, *code: Any) -> str | None:
        """
        Format some input code, or an existing file.

        Returns decoded result unless formatting an existing file.

        Raises
        ------
        subprocess.CalledProcessError
            If ``ruff`` exits with a non-zero status.
        """
        base = "ruff", "format"
        if len(code) == 1 and isinstance(code[0], Path):
            subprocess.run((*base, code[0], *self.format_args), check=True)
            return None
        # A lone ``bytes`` argument is already-encoded output (e.g. from ``check``)
        # and is piped through untouched; anything else is flattened and encoded.
        encoded = (
            code[0]
            if len(code) == 1 and isinstance(code[0], bytes)
            else self._join(code).encode()
        )
        r = subprocess.run(
            (*base, *self.format_args, *self._stdin_args),
            input=encoded,
            check=True,
            capture_output=True,
        )
        return r.stdout.decode()

    def __call__(self, *code: _Code) -> str:
        """Run ``check`` then ``format`` on ``code``, returning the final text."""
        return self.format(self.check(code, decode=False))
+
+
# Used for function bodies extracted into the docs: E711 is ignored, and the
# formatter's ``skip-magic-trailing-comma`` is enabled because the code changes
# indent level when it is unwrapped from its function.
ruff_inline_docs = Ruff(ignore="E711", skip_magic_traling_comma=True)
# Default instance; adds only ``--fix`` on top of the project config.
ruff = Ruff()
diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py
index 3b120bc77..0db8772e8 100644
--- a/tools/generate_schema_wrapper.py
+++ b/tools/generate_schema_wrapper.py
@@ -20,6 +20,7 @@
sys.path.insert(0, str(Path.cwd()))
+from tools.codemod import ruff
from tools.markup import rst_syntax_for_class
from tools.schemapi import CodeSnippet, SchemaInfo, arg_kwds, arg_required_kwds, codegen
from tools.schemapi.utils import (
@@ -31,8 +32,6 @@
import_typing_extensions,
indent_docstring,
resolve_references,
- ruff_format_py,
- ruff_write_lint_format_str,
spell_literal,
)
from tools.vega_expr import write_expr_module
@@ -548,7 +547,7 @@ def copy_schemapi_util() -> None:
dest.write(HEADER_COMMENT)
dest.writelines(source.readlines())
if sys.platform == "win32":
- ruff_format_py(destination_fp)
+ ruff.format(destination_fp)
def recursive_dict_update(schema: dict, root: dict, def_dict: dict) -> None:
@@ -1024,7 +1023,7 @@ def vegalite_main(skip_download: bool = False) -> None:
f"SCHEMA_VERSION = '{version}'\n",
f"SCHEMA_URL = {schema_url(version)!r}\n",
]
- ruff_write_lint_format_str(outfile, content)
+ ruff.write_lint_format(outfile, content)
TypeAliasTracer.update_aliases(("Map", "Mapping[str, Any]"))
@@ -1106,7 +1105,7 @@ def vegalite_main(skip_download: bool = False) -> None:
# Write the pre-generated modules
for fp, contents in files.items():
print(f"Writing\n {schemafile!s}\n ->{fp!s}")
- ruff_write_lint_format_str(fp, contents)
+ ruff.write_lint_format(fp, contents)
def generate_encoding_artifacts(
diff --git a/tools/schemapi/utils.py b/tools/schemapi/utils.py
index 6bc7b1f4b..351c355d0 100644
--- a/tools/schemapi/utils.py
+++ b/tools/schemapi/utils.py
@@ -4,7 +4,6 @@
import json
import re
-import subprocess
import sys
import textwrap
import urllib.parse
@@ -19,13 +18,13 @@
Iterator,
Literal,
Mapping,
- MutableSequence,
Sequence,
TypeVar,
Union,
overload,
)
+from tools.codemod import ruff
from tools.markup import RSTParseVegaLite, rst_syntax_for_class
from tools.schemapi.schemapi import _resolve_references as resolve_references
@@ -91,13 +90,6 @@ class _TypeAliasTracer:
A format specifier to produce the `TypeAlias` name.
Will be provided a `SchemaInfo.title` as a single positional argument.
- *ruff_check
- Optional [ruff rule codes](https://docs.astral.sh/ruff/rules/),
- each prefixed with `--select ` and follow a `ruff check --fix ` call.
-
- If not provided, uses `[tool.ruff.lint.select]` from `pyproject.toml`.
- ruff_format
- Optional argument list supplied to [ruff format](https://docs.astral.sh/ruff/formatter/#ruff-format)
Attributes
----------
@@ -111,12 +103,7 @@ class _TypeAliasTracer:
Prefined import statements to appear at beginning of module.
"""
- def __init__(
- self,
- fmt: str = "{}_T",
- *ruff_check: str,
- ruff_format: Sequence[str] | None = None,
- ) -> None:
+ def __init__(self, fmt: str = "{}_T") -> None:
self.fmt: str = fmt
self._literals: dict[str, str] = {}
self._literals_invert: dict[str, str] = {}
@@ -135,10 +122,6 @@ def __init__(
import_typing_extensions((3, 10), "TypeAlias"),
import_typing_extensions((3, 9), "Annotated", "get_args"),
)
- self._cmd_check: list[str] = ["--fix"]
- self._cmd_format: Sequence[str] = ruff_format or ()
- for c in ruff_check:
- self._cmd_check.extend(("--extend-select", c))
def _update_literals(self, name: str, tp: str, /) -> None:
"""Produces an inverted index, to reuse a `Literal` when `SchemaInfo.title` is empty."""
@@ -223,13 +206,6 @@ def write_module(
extra
`tools.generate_schema_wrapper.TYPING_EXTRA`.
"""
- ruff_format: MutableSequence[str | Path] = ["ruff", "format", fp]
- if self._cmd_format:
- ruff_format.extend(self._cmd_format)
- commands: tuple[Sequence[str | Path], ...] = (
- ["ruff", "check", fp, *self._cmd_check],
- ruff_format,
- )
static = (header, "\n", *self._imports, "\n\n")
self.update_aliases(*sorted(self._literals.items(), key=itemgetter(0)))
all_ = [*iter(self._aliases), *extra_all]
@@ -238,10 +214,7 @@ def write_module(
[f"__all__ = {all_}", "\n\n", extra],
self.generate_aliases(),
)
- fp.write_text("\n".join(it), encoding="utf-8")
- for cmd in commands:
- r = subprocess.run(cmd, check=True)
- r.check_returncode()
+ ruff.write_lint_format(fp, it)
@property
def n_entries(self) -> int:
@@ -997,49 +970,6 @@ def unwrap_literal(tp: str, /) -> str:
return re.sub(r"Literal\[(.+)\]", r"\g<1>", tp)
-def ruff_format_py(fp: Path, /, *extra_args: str) -> None:
- """
- Format an existing file.
-
- Running on `win32` after writing lines will ensure "lf" is used before:
- ```bash
- ruff format --diff --check .
- ```
- """
- cmd: MutableSequence[str | Path] = ["ruff", "format", fp]
- if extra_args:
- cmd.extend(extra_args)
- r = subprocess.run(cmd, check=True)
- r.check_returncode()
-
-
-def ruff_write_lint_format_str(
- fp: Path, code: str | Iterable[str], /, *, encoding: str = "utf-8"
-) -> None:
- """
- Combined steps of writing, `ruff check`, `ruff format`.
-
- Notes
- -----
- - `fp` is written to first, as the size before formatting will be the smallest
- - Better utilizes `ruff` performance, rather than `python` str and io
- - `code` is no longer bound to `list`
- - Encoding set as default
- - `I001/2` are `isort` rules, to sort imports.
- """
- commands: Iterable[Sequence[str | Path]] = (
- ["ruff", "check", fp, "--fix"],
- ["ruff", "check", fp, "--fix", "--select", "I001", "--select", "I002"],
- )
- if not isinstance(code, str):
- code = "\n".join(code)
- fp.write_text(code, encoding=encoding)
- for cmd in commands:
- r = subprocess.run(cmd, check=True)
- r.check_returncode()
- ruff_format_py(fp)
-
-
def import_type_checking(*imports: str) -> str:
"""Write an `if TYPE_CHECKING` block."""
imps = "\n".join(f" {s}" for s in imports)
@@ -1066,7 +996,7 @@ def import_typing_extensions(
)
-TypeAliasTracer: _TypeAliasTracer = _TypeAliasTracer("{}_T", "I001", "I002")
+TypeAliasTracer: _TypeAliasTracer = _TypeAliasTracer("{}_T")
"""An instance of `_TypeAliasTracer`.
Collects a cache of unique `Literal` types used globally.
diff --git a/tools/update_init_file.py b/tools/update_init_file.py
index c1831093a..0592032f0 100644
--- a/tools/update_init_file.py
+++ b/tools/update_init_file.py
@@ -8,7 +8,7 @@
from pathlib import Path
from typing import TYPE_CHECKING
-from tools.schemapi.utils import ruff_write_lint_format_str
+from tools.codemod import ruff
_TYPING_CONSTRUCTS = {
te.TypeAlias,
@@ -74,7 +74,7 @@ def update__all__variable() -> None:
]
# Write new version of altair/__init__.py
# Format file content with ruff
- ruff_write_lint_format_str(init_path, new_lines)
+ ruff.write_lint_format(init_path, new_lines)
def relevant_attributes(namespace: dict[str, t.Any], /) -> list[str]:
diff --git a/tools/vega_expr.py b/tools/vega_expr.py
index ce87cb2fb..66d6287fb 100644
--- a/tools/vega_expr.py
+++ b/tools/vega_expr.py
@@ -29,12 +29,10 @@
overload,
)
+from tools.codemod import ruff
from tools.markup import RSTParse, Token, read_ast_tokens
from tools.markup import RSTRenderer as _RSTRenderer
from tools.schemapi.schemapi import SchemaBase as _SchemaBase
-from tools.schemapi.utils import (
- ruff_write_lint_format_str as _ruff_write_lint_format_str,
-)
if TYPE_CHECKING:
import sys
@@ -977,4 +975,4 @@ def write_expr_module(version: str, output: Path, *, header: str) -> None:
[MODULE_POST],
)
print(f"Generating\n {url!s}\n ->{output!s}")
- _ruff_write_lint_format_str(output, contents)
+ ruff.write_lint_format(output, contents)