Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/arviz_base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def _check_tilde_start(x):
return bool(isinstance(x, str) and x.startswith("~"))


def _var_names(var_names, data, filter_vars=None):
def _var_names(var_names, data, filter_vars=None, check_if_present=True):
"""Handle var_names input across arviz.

Parameters
Expand All @@ -22,6 +22,9 @@ def _var_names(var_names, data, filter_vars=None):
interpret var_names as substrings of the real variables names. If "regex",
interpret var_names as regular expressions on the real variables names. A la
`pandas.filter`.
check_if_present : bool, optional
If True (default), raise an error if any of the var_names is not present in
the data. If False, ignore missing var_names.

Returns
-------
Expand Down Expand Up @@ -52,14 +55,20 @@ def _var_names(var_names, data, filter_vars=None):
)

try:
var_names = _subset_list(var_names, all_vars, filter_items=filter_vars, warn=False)
var_names = _subset_list(
var_names,
all_vars,
filter_items=filter_vars,
warn=False,
check_if_present=check_if_present,
)
except KeyError as err:
msg = " ".join(("var names:", f"{err}", "in dataset"))
raise KeyError(msg) from err
return var_names


def _subset_list(subset, whole_list, filter_items=None, warn=True):
def _subset_list(subset, whole_list, filter_items=None, warn=True, check_if_present=True):
"""Handle list subsetting (var_names, groups...) across arviz.

Parameters
Expand Down Expand Up @@ -125,7 +134,7 @@ def _subset_list(subset, whole_list, filter_items=None, warn=True):
subset = [item for item in whole_list for name in subset if re.search(name, item)]

existing_items = np.isin(subset, whole_list)
if not np.all(existing_items):
if check_if_present and not np.all(existing_items):
raise KeyError(f"{np.array(subset)[~existing_items]} are not present")

return subset
Expand Down
2 changes: 2 additions & 0 deletions src/arviz_base/utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ def _var_names(
var_names: str | list | None,
data: xarray.Dataset | Sequence[xarray.Dataset],
filter_vars: Literal[None, "like", "regex"] | None = ...,
check_if_present: bool = ...,
) -> list | None: ...
def _subset_list(
subset: str,
whole_list: list,
filter_items: Literal[None, "like", "regex"] | None = ...,
warn=...,
check_if_present=...,
) -> list | None: ...
def _get_coords(
data: xarray.DataArray, coords: dict[Hashable, ArrayLike]
Expand Down
8 changes: 8 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def test_var_names_key_error(data):
_var_names(["theta", "tau", "bad_var_name"], data)


def test_var_names_skip_check_if_present(data):
assert _var_names(["theta", "tau", "bad_var_name"], data, check_if_present=False) == [
"theta",
"tau",
"bad_var_name",
]


@pytest.mark.parametrize(
"var_args",
[
Expand Down
Loading