From d374ecafdab04d4453a7c8172a8c492fb5e80b45 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 27 May 2025 16:22:16 -0700 Subject: [PATCH] [WIP] Add "The stack" section to left nav The new section contains an overview of the libraries in the stack, as well as a page for each library with a brief description and outbound links for more information. --- docs/requirements.txt | 1 + docs/source/conf.py | 1 + docs/source/index.rst | 14 ++++++++++++++ docs/source/stack_chex.md | 1 + docs/source/stack_flax.md | 24 ++++++++++++++++++++++++ docs/source/stack_grain.md | 1 + docs/source/stack_jax.md | 18 ++++++++++++++++++ docs/source/stack_optax.md | 12 ++++++++++++ docs/source/stack_orbax_checkpoint.md | 1 + docs/source/stack_orbax_export.md | 1 + docs/source/stack_overview.md | 26 ++++++++++++++++++++++++++ 11 files changed, 100 insertions(+) create mode 100644 docs/source/stack_chex.md create mode 100644 docs/source/stack_flax.md create mode 100644 docs/source/stack_grain.md create mode 100644 docs/source/stack_jax.md create mode 100644 docs/source/stack_optax.md create mode 100644 docs/source/stack_orbax_checkpoint.md create mode 100644 docs/source/stack_orbax_export.md create mode 100644 docs/source/stack_overview.md diff --git a/docs/requirements.txt b/docs/requirements.txt index 5dc8941..f0e5cfe 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,6 +5,7 @@ myst-nb myst-parser[linkify] sphinx-book-theme sphinx-copybutton +sphinx-design # Packages required for notebook execution matplotlib diff --git a/docs/source/conf.py b/docs/source/conf.py index 36bc79a..7d69737 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,6 +17,7 @@ extensions = [ 'myst_nb', 'sphinx_copybutton', + 'sphinx_design', ] templates_path = ['_templates'] diff --git a/docs/source/index.rst b/docs/source/index.rst index 4160827..cd27e01 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -12,6 +12,20 @@ JAX AI Stack install getting_started +.. toctree:: + :hidden: + :caption: The stack + :maxdepth: 1 + + stack_overview + stack_jax + stack_flax + stack_optax + stack_orbax_checkpoint + stack_orbax_export + stack_grain + stack_chex + .. toctree:: :hidden: :caption: Tutorials diff --git a/docs/source/stack_chex.md b/docs/source/stack_chex.md new file mode 100644 index 0000000..4f3ba00 --- /dev/null +++ b/docs/source/stack_chex.md @@ -0,0 +1 @@ +# Chex: test utilities diff --git a/docs/source/stack_flax.md b/docs/source/stack_flax.md new file mode 100644 index 0000000..3d4c492 --- /dev/null +++ b/docs/source/stack_flax.md @@ -0,0 +1,24 @@ +# Flax NNX: neural nets + +Flax NNX provides **neural net functionality** on top of JAX, such as a module +abstraction and pre-defined layers, via a **Pythonic object-oriented API**. NNX +allows you to write stateful model code that can still take advantage of JAX's +function transforms and other features. + +NNX has native integration with [Optax](stack_optax). + +Main Flax NNX site: +**[flax.readthedocs.io{material-regular}`open_in_new`](https://flax.readthedocs.io/)** + +**If you'd like to learn more about NNX** beyond what's covered in the +[](getting_started) guide, we recommend starting with **[Flax +basics{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/nnx_basics.html)**. + +The Flax NNX docs cover many other useful topics including: + +* [Function + transforms{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/transforms.html) +* [Parallelism{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) +* [Performance + considerations{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/performance.html) +* And much more! diff --git a/docs/source/stack_grain.md b/docs/source/stack_grain.md new file mode 100644 index 0000000..d320a46 --- /dev/null +++ b/docs/source/stack_grain.md @@ -0,0 +1 @@ +# Grain: data loading diff --git a/docs/source/stack_jax.md b/docs/source/stack_jax.md new file mode 100644 index 0000000..9c9caff --- /dev/null +++ b/docs/source/stack_jax.md @@ -0,0 +1,18 @@ +# JAX: array computing + +JAX is the foundation of the JAX AI Stack! It provides **high-performance array +computing** functionality over accelerators via a simple **NumPy-like API and +function transformations**. + +Main JAX site: **[jax.dev{material-regular}`open_in_new`](https://jax.dev)** + +**If you'd like to learn more about JAX** beyond what's covered in the +[](getting_started) guide, we recommend starting with the **[JAX +tutorials{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/tutorials.html)**. + +The JAX docs cover many other useful topics including: + +* [Performance profiling{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/profiling.html) +* [Multi-host JAX programs{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/multi_process.html) +* [Custom GPU + TPU kernels with Pallas{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/pallas/index.html) +* And much more! diff --git a/docs/source/stack_optax.md b/docs/source/stack_optax.md new file mode 100644 index 0000000..b97d40e --- /dev/null +++ b/docs/source/stack_optax.md @@ -0,0 +1,12 @@ +# Optax: optimizers + +Optax provides **gradient processing and optimization** functionality on top of +JAX, including optimizers and losses. + +Main Optax site: +**[optax.readthedocs.io{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/index.html)** + +**If you'd like to learn more about Optax** beyond what's covered in the +[](getting_started) guide, we recommend starting with the **[Optax getting +started{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/getting_started.html)** +guide. diff --git a/docs/source/stack_orbax_checkpoint.md b/docs/source/stack_orbax_checkpoint.md new file mode 100644 index 0000000..36c2255 --- /dev/null +++ b/docs/source/stack_orbax_checkpoint.md @@ -0,0 +1 @@ +# Orbax: checkpointing diff --git a/docs/source/stack_orbax_export.md b/docs/source/stack_orbax_export.md new file mode 100644 index 0000000..a9f9ae7 --- /dev/null +++ b/docs/source/stack_orbax_export.md @@ -0,0 +1 @@ +# Orbax: model export diff --git a/docs/source/stack_overview.md b/docs/source/stack_overview.md new file mode 100644 index 0000000..160c9e1 --- /dev/null +++ b/docs/source/stack_overview.md @@ -0,0 +1,26 @@ +# Stack overview + +The JAX AI Stack is comprised of the following packages: + +* [JAX{material-regular}`open_in_new`](https://jax.dev): high-performance array + computing +* [Flax + NNX{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/): + object-oriented neural nets +* [Optax{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/index.html): + optimizers +* [Orbax{material-regular}`open_in_new`](https://orbax.readthedocs.io/en/latest/): + checkpointing and model export +* [Grain{material-regular}`open_in_new`](https://google-grain.readthedocs.io/en/latest/): + JAX-native data loading +* [Chex{material-regular}`open_in_new`](https://chex.readthedocs.io/en/latest/): + JAX test utilities + +The `jax-ai-stack` metapackage installs compatible versions of all of these +libraries, as well as shared compatible versions of shared dependencies. + +In addition, there is an optional `jax-ai-stack[tfds]` installation that +includes [TensorFlow +Datasets{material-regular}`open_in_new`](https://www.tensorflow.org/datasets), +for those who wish to use TFDS for data loading. This includes a compatible +version of TensorFlow as well.