Skip to content

Commit

Permalink
Fix documentation and acceptable values for filter_pmap's 'donate'
Browse files Browse the repository at this point in the history
  • Loading branch information
colehaus authored and patrick-kidger committed Sep 1, 2023
1 parent a234840 commit 8c39b21
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
12 changes: 8 additions & 4 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import warnings
from collections.abc import Callable
from typing import Any, overload, TypeVar
from typing import Any, Literal, overload, TypeVar
from typing_extensions import ParamSpec

import jax
Expand Down Expand Up @@ -196,18 +196,22 @@ def __get__(self, instance, owner):

@overload
def filter_jit(
*, donate: str = "none"
*, donate: Literal["all", "warn", "none"] = "none"
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
...


@overload
def filter_jit(fun: Callable[_P, _T], *, donate: str = "none") -> Callable[_P, _T]:
def filter_jit(
fun: Callable[_P, _T], *, donate: Literal["all", "warn", "none"] = "none"
) -> Callable[_P, _T]:
...


@doc_remove_args("jitkwargs")
def filter_jit(fun=sentinel, *, donate: str = "none", **jitkwargs):
def filter_jit(
fun=sentinel, *, donate: Literal["all", "warn", "none"] = "none", **jitkwargs
):
"""An easier-to-use version of `jax.jit`. All JAX and NumPy arrays are traced, and
all other types are held static.
Expand Down
25 changes: 16 additions & 9 deletions equinox/_vmap_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import warnings
from collections.abc import Callable, Hashable
from typing import Any, Optional, overload, Union
from typing import Any, Literal, Optional, overload, Union

import jax
import jax._src.traceback_util as traceback_util
Expand Down Expand Up @@ -572,7 +572,7 @@ def filter_pmap(
out_axes: PyTree[AxisSpec] = if_array(0),
axis_name: Hashable = None,
axis_size: Optional[int] = None,
donate: str = "none",
donate: Literal["all", "warn", "none"] = "none",
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
...

Expand All @@ -585,7 +585,7 @@ def filter_pmap(
out_axes: PyTree[AxisSpec] = if_array(0),
axis_name: Hashable = None,
axis_size: Optional[int] = None,
donate: str = "none",
donate: Literal["all", "warn", "none"] = "none",
) -> Callable[..., Any]:
...

Expand All @@ -598,7 +598,7 @@ def filter_pmap(
out_axes: PyTree[AxisSpec] = if_array(0),
axis_name: Hashable = None,
axis_size: Optional[int] = None,
donate: str = "none",
donate: Literal["all", "warn", "none"] = "none",
**pmapkwargs,
):
"""Parallelises a function. By default, all JAX/NumPy arrays are parallelised down
Expand Down Expand Up @@ -635,10 +635,10 @@ def filter_pmap(
it can be deduced by looking at the argument shapes.
- `donate` indicates whether the buffers of JAX arrays are donated or not, it
should either be:
- `'all'`: the default, donate all arrays and suppress all warnings about
- `'all'`: donate all arrays and suppress all warnings about
unused buffers;
- `'warn'`: as above, but don't suppress unused buffer warnings;
- `'none'`: disables buffer donation.
- `'none'`: the default, disables buffer donation.
**Returns:**
Expand Down Expand Up @@ -698,11 +698,18 @@ def g(x, y):
"'donate_argnums'"
)

if donate not in {"arrays", "warn", "none"}:
if donate == "arrays":
warnings.warn(
"The `donate='arrays'` option to `filter_pmap` has been renamed to "
"`donate='all'`",
DeprecationWarning,
)
donate = "all"
if donate not in {"all", "warn", "none"}:
raise ValueError(
"`filter_jit(..., donate=...)` must be one of 'arrays', 'warn', or 'none'"
"`filter_jit(..., donate=...)` must be one of 'all', 'warn', or 'none'"
)
filter_warning = True if donate == "arrays" else False
filter_warning = True if donate == "all" else False
if donate != "none":
pmapkwargs["donate_argnums"] = (0,)

Expand Down

0 comments on commit 8c39b21

Please sign in to comment.