diff --git a/src/arviz_base/utils.py b/src/arviz_base/utils.py index 97019b1..416651f 100644 --- a/src/arviz_base/utils.py +++ b/src/arviz_base/utils.py @@ -10,7 +10,7 @@ def _check_tilde_start(x): return bool(isinstance(x, str) and x.startswith("~")) -def _var_names(var_names, data, filter_vars=None): +def _var_names(var_names, data, filter_vars=None, check_if_present=True): """Handle var_names input across arviz. Parameters @@ -22,6 +22,9 @@ def _var_names(var_names, data, filter_vars=None): interpret var_names as substrings of the real variables names. If "regex", interpret var_names as regular expressions on the real variables names. A la `pandas.filter`. + check_if_present : bool, optional + If True (default), raise an error if any of the var_names is not present in + the data. If False, skip this check and keep missing var_names in the output. Returns ------- @@ -52,14 +55,20 @@ def _var_names(var_names, data, filter_vars=None): ) try: - var_names = _subset_list(var_names, all_vars, filter_items=filter_vars, warn=False) + var_names = _subset_list( + var_names, + all_vars, + filter_items=filter_vars, + warn=False, + check_if_present=check_if_present, + ) except KeyError as err: msg = " ".join(("var names:", f"{err}", "in dataset")) raise KeyError(msg) from err return var_names -def _subset_list(subset, whole_list, filter_items=None, warn=True): +def _subset_list(subset, whole_list, filter_items=None, warn=True, check_if_present=True): """Handle list subsetting (var_names, groups...) across arviz. 
Parameters @@ -125,7 +134,7 @@ def _subset_list(subset, whole_list, filter_items=None, warn=True): subset = [item for item in whole_list for name in subset if re.search(name, item)] existing_items = np.isin(subset, whole_list) - if not np.all(existing_items): + if check_if_present and not np.all(existing_items): raise KeyError(f"{np.array(subset)[~existing_items]} are not present") return subset diff --git a/src/arviz_base/utils.pyi b/src/arviz_base/utils.pyi index a7f2b29..df8da80 100644 --- a/src/arviz_base/utils.pyi +++ b/src/arviz_base/utils.pyi @@ -15,12 +15,14 @@ def _var_names( var_names: str | list | None, data: xarray.Dataset | Sequence[xarray.Dataset], filter_vars: Literal[None, "like", "regex"] | None = ..., + check_if_present: bool = ..., ) -> list | None: ... def _subset_list( subset: str, whole_list: list, filter_items: Literal[None, "like", "regex"] | None = ..., warn=..., + check_if_present=..., ) -> list | None: ... def _get_coords( data: xarray.DataArray, coords: dict[Hashable, ArrayLike] diff --git a/tests/test_utils.py b/tests/test_utils.py index a81ec8f..5e168c8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -58,6 +58,14 @@ def test_var_names_key_error(data): _var_names(["theta", "tau", "bad_var_name"], data) +def test_var_names_skip_check_if_present(data): + assert _var_names(["theta", "tau", "bad_var_name"], data, check_if_present=False) == [ + "theta", + "tau", + "bad_var_name", + ] + + @pytest.mark.parametrize( "var_args", [