Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix documentation and acceptable values for filter_pmap's 'donate' #468

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import warnings
from collections.abc import Callable
from typing import Any, overload, TypeVar
from typing import Any, Literal, overload, TypeVar
from typing_extensions import ParamSpec

import jax
Expand Down Expand Up @@ -196,18 +196,22 @@ def __get__(self, instance, owner):

@overload
def filter_jit(
*, donate: str = "none"
*, donate: Literal["all", "warn", "none"] = "none"
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
...


@overload
def filter_jit(fun: Callable[_P, _T], *, donate: str = "none") -> Callable[_P, _T]:
def filter_jit(
fun: Callable[_P, _T], *, donate: Literal["all", "warn", "none"] = "none"
) -> Callable[_P, _T]:
...


@doc_remove_args("jitkwargs")
def filter_jit(fun=sentinel, *, donate: str = "none", **jitkwargs):
def filter_jit(
fun=sentinel, *, donate: Literal["all", "warn", "none"] = "none", **jitkwargs
):
"""An easier-to-use version of `jax.jit`. All JAX and NumPy arrays are traced, and
all other types are held static.

Expand Down
25 changes: 16 additions & 9 deletions equinox/_vmap_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import warnings
from collections.abc import Callable, Hashable
from typing import Any, Optional, overload, Union
from typing import Any, Literal, Optional, overload, Union

import jax
import jax._src.traceback_util as traceback_util
Expand Down Expand Up @@ -572,7 +572,7 @@ def filter_pmap(
out_axes: PyTree[AxisSpec] = if_array(0),
axis_name: Hashable = None,
axis_size: Optional[int] = None,
donate: str = "none",
donate: Literal["all", "warn", "none"] = "none",
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
...

Expand All @@ -585,7 +585,7 @@ def filter_pmap(
out_axes: PyTree[AxisSpec] = if_array(0),
axis_name: Hashable = None,
axis_size: Optional[int] = None,
donate: str = "none",
donate: Literal["all", "warn", "none"] = "none",
) -> Callable[..., Any]:
...

Expand All @@ -598,7 +598,7 @@ def filter_pmap(
out_axes: PyTree[AxisSpec] = if_array(0),
axis_name: Hashable = None,
axis_size: Optional[int] = None,
donate: str = "none",
donate: Literal["all", "warn", "none"] = "none",
**pmapkwargs,
):
"""Parallelises a function. By default, all JAX/NumPy arrays are parallelised down
Expand Down Expand Up @@ -635,10 +635,10 @@ def filter_pmap(
it can be deduced by looking at the argument shapes.
- `donate` indicates whether the buffers of JAX arrays are donated or not, it
should either be:
- `'all'`: the default, donate all arrays and suppress all warnings about
- `'all'`: donate all arrays and suppress all warnings about
unused buffers;
- `'warn'`: as above, but don't suppress unused buffer warnings;
- `'none'`: disables buffer donation.
- `'none'`: the default, disables buffer donation.

**Returns:**

Expand Down Expand Up @@ -698,11 +698,18 @@ def g(x, y):
"'donate_argnums'"
)

if donate not in {"arrays", "warn", "none"}:
if donate == "arrays":
warnings.warn(
"The `donate='arrays'` option to `filter_pmap` has been renamed to "
"`donate='all'`",
DeprecationWarning,
)
donate = "all"
if donate not in {"all", "warn", "none"}:
raise ValueError(
"`filter_jit(..., donate=...)` must be one of 'arrays', 'warn', or 'none'"
"`filter_jit(..., donate=...)` must be one of 'all', 'warn', or 'none'"
)
filter_warning = True if donate == "arrays" else False
filter_warning = True if donate == "all" else False
if donate != "none":
pmapkwargs["donate_argnums"] = (0,)

Expand Down