From 437c7cdfb531d23b0f8972f7273d9b3dee9f833d Mon Sep 17 00:00:00 2001 From: Cole Haus Date: Fri, 1 Sep 2023 13:03:00 -0700 Subject: [PATCH] Fix documentation and acceptable values for `filter_pmap`'s 'donate' --- equinox/_jit.py | 12 ++++++++---- equinox/_vmap_pmap.py | 25 ++++++++++++++++--------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/equinox/_jit.py b/equinox/_jit.py index c6586609..633fe99a 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -2,7 +2,7 @@ import inspect import warnings from collections.abc import Callable -from typing import Any, overload, TypeVar +from typing import Any, Literal, overload, TypeVar from typing_extensions import ParamSpec import jax @@ -196,18 +196,22 @@ def __get__(self, instance, owner): @overload def filter_jit( - *, donate: str = "none" + *, donate: Literal["all", "warn", "none"] = "none" ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... @overload -def filter_jit(fun: Callable[_P, _T], *, donate: str = "none") -> Callable[_P, _T]: +def filter_jit( + fun: Callable[_P, _T], *, donate: Literal["all", "warn", "none"] = "none" +) -> Callable[_P, _T]: ... @doc_remove_args("jitkwargs") -def filter_jit(fun=sentinel, *, donate: str = "none", **jitkwargs): +def filter_jit( + fun=sentinel, *, donate: Literal["all", "warn", "none"] = "none", **jitkwargs +): """An easier-to-use version of `jax.jit`. All JAX and NumPy arrays are traced, and all other types are held static. diff --git a/equinox/_vmap_pmap.py b/equinox/_vmap_pmap.py index 0f5a4408..ce207b01 100644 --- a/equinox/_vmap_pmap.py +++ b/equinox/_vmap_pmap.py @@ -3,7 +3,7 @@ import inspect import warnings from collections.abc import Callable, Hashable -from typing import Any, Optional, overload, Union +from typing import Any, Literal, Optional, overload, Union import jax import jax._src.traceback_util as traceback_util @@ -572,7 +572,7 @@ def filter_pmap( out_axes: PyTree[AxisSpec] = if_array(0), axis_name: Hashable = None, axis_size: Optional[int] = None, - donate: str = "none", + donate: Literal["all", "warn", "none"] = "none", ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ... @@ -585,7 +585,7 @@ def filter_pmap( out_axes: PyTree[AxisSpec] = if_array(0), axis_name: Hashable = None, axis_size: Optional[int] = None, - donate: str = "none", + donate: Literal["all", "warn", "none"] = "none", ) -> Callable[..., Any]: ... @@ -598,7 +598,7 @@ def filter_pmap( out_axes: PyTree[AxisSpec] = if_array(0), axis_name: Hashable = None, axis_size: Optional[int] = None, - donate: str = "none", + donate: Literal["all", "warn", "none"] = "none", **pmapkwargs, ): """Parallelises a function. By default, all JAX/NumPy arrays are parallelised down @@ -635,10 +635,10 @@ def filter_pmap( it can be deduced by looking at the argument shapes. - `donate` indicates whether the buffers of JAX arrays are donated or not, it should either be: - - `'all'`: the default, donate all arrays and suppress all warnings about + - `'all'`: donate all arrays and suppress all warnings about unused buffers; - `'warn'`: as above, but don't suppress unused buffer warnings; - - `'none'`: disables buffer donation. + - `'none'`: the default, disables buffer donation. **Returns:** @@ -698,11 +698,18 @@ def g(x, y): "'donate_argnums'" ) - if donate not in {"arrays", "warn", "none"}: + if donate == "arrays": + warnings.warn( + "The `donate='arrays'` option to `filter_pmap` has been renamed to " + "`donate='all'`", + DeprecationWarning, + ) + donate == "all" + if donate not in {"all", "warn", "none"}: raise ValueError( - "`filter_jit(..., donate=...)` must be one of 'arrays', 'warn', or 'none'" + "`filter_jit(..., donate=...)` must be one of 'all', 'warn', or 'none'" ) - filter_warning = True if donate == "arrays" else False + filter_warning = True if donate == "all" else False if donate != "none": pmapkwargs["donate_argnums"] = (0,)