From c2d6689238cb11c73c803768188cd40281cbadde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Ha=C3=ABck?= <52320542+Descanonge@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:25:15 +0200 Subject: [PATCH 1/5] Cache flag_dict --- cf_xarray/accessor.py | 107 +++++++++++++++++++++--------------------- 1 file changed, 53 insertions(+), 54 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index a915335b..0ea17643 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1114,41 +1114,6 @@ def __getattr__(self, attr): ) -def create_flag_dict(da) -> Mapping[Hashable, FlagParam]: - """ - Return possible flag meanings and associated bitmask/values. - - The mapping values are a tuple containing a bitmask and a value. Either - can be None. - If only a bitmask: Independent flags. - If only a value: Mutually exclusive flags. - If both: Mix of independent and mutually exclusive flags. - """ - if not da.cf.is_flag_variable: - raise ValueError( - "Comparisons are only supported for DataArrays that represent " - "CF flag variables. .attrs must contain 'flag_meanings' and " - "'flag_values' or 'flag_masks'." - ) - - flag_meanings = da.attrs["flag_meanings"].split(" ") - n_flag = len(flag_meanings) - - flag_values = da.attrs.get("flag_values", [None] * n_flag) - flag_masks = da.attrs.get("flag_masks", [None] * n_flag) - - if not (n_flag == len(flag_values) == len(flag_masks)): - raise ValueError( - "Not as many flag meanings as values or masks. " - "Please check the flag_meanings, flag_values, flag_masks attributes " - ) - - flag_params = tuple( - FlagParam(mask, value) for mask, value in zip(flag_masks, flag_values) - ) - return dict(zip(flag_meanings, flag_params)) - - class CFAccessor: """ Common Dataset and DataArray accessor functionality. @@ -1157,23 +1122,62 @@ class CFAccessor: def __init__(self, obj): self._obj = obj self._all_cell_measures = None + self._flag_dict: Mapping[Hashable, FlagParam] | None = None def __setstate__(self, d): self.__dict__ = d - def _assert_valid_other_comparison(self, other): - # TODO cache this property - flag_dict = create_flag_dict(self._obj) - if other not in flag_dict: + @property + def flag_dict(self) -> Mapping[Hashable, FlagParam]: + """ + Return possible flag meanings and associated bitmask/values. + + The mapping values are a tuple containing a bitmask and a value. Either + can be None. + If only a bitmask: Independent flags. + If only a value: Mutually exclusive flags. + If both: Mix of independent and mutually exclusive flags. + """ + if self._flag_dict is not None: + return self._flag_dict + + da = self._obj + + if not da.cf.is_flag_variable: raise ValueError( - f"Did not find flag value meaning [{other}] in known flag meanings: [{flag_dict.keys()!r}]" + "Comparisons are only supported for DataArrays that represent " + "CF flag variables. .attrs must contain 'flag_meanings' and " + "'flag_values' or 'flag_masks'." ) - if flag_dict[other].flag_mask is not None: + + flag_meanings = da.attrs["flag_meanings"].split(" ") + n_flag = len(flag_meanings) + + flag_values = da.attrs.get("flag_values", [None] * n_flag) + flag_masks = da.attrs.get("flag_masks", [None] * n_flag) + + if not (n_flag == len(flag_values) == len(flag_masks)): + raise ValueError( + "Not as many flag meanings as values or masks. " + "Please check the flag_meanings, flag_values, flag_masks attributes " + ) + + flag_params = tuple( + FlagParam(mask, value) for mask, value in zip(flag_masks, flag_values) + ) + return dict(zip(flag_meanings, flag_params)) + + def _assert_valid_other_comparison(self, other: Hashable) -> Mapping[Hashable, FlagParam]: + if other not in self.flag_dict: + raise ValueError( + f"Did not find flag value meaning [{other}] in known flag meanings: [{self.flag_dict.keys()!r}]" + ) + if self.flag_dict[other].flag_mask is not None: raise NotImplementedError( "Only equals and not-equals comparisons with flag masks are supported." " Please open an issue." ) - return flag_dict + return self.flag_dict def __eq__(self, other) -> DataArray: # type: ignore[override] """ @@ -1320,15 +1324,13 @@ def isin(self, test_elements) -> DataArray: raise ValueError( ".cf.isin is only supported on DataArrays that contain CF flag attributes." ) - # TODO cache this property - flag_dict = create_flag_dict(self._obj) mapped_test_elements = [] for elem in test_elements: - if elem not in flag_dict: + if elem not in self.flag_dict: raise ValueError( - f"Did not find flag value meaning [{elem}] in known flag meanings: [{flag_dict.keys()!r}]" + f"Did not find flag value meaning [{elem}] in known flag meanings: [{self.flag_dict.keys()!r}]" ) - mapped_test_elements.append(flag_dict[elem].flag_value) + mapped_test_elements.append(self.flag_dict[elem].flag_value) return self._obj.isin(mapped_test_elements) def _drop_missing_variables(self, variables: list[Hashable]) -> list[Hashable]: @@ -2985,11 +2987,8 @@ def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset: Flags to extract. If empty (string or list), return all flags in `flag_meanings`. """ - # TODO cache this property - flag_dict = create_flag_dict(self._obj) - if flags is None: - flags = tuple(flag_dict.keys()) + flags = tuple(self.flag_dict.keys()) out = {} # Output arrays @@ -2997,12 +2996,12 @@ def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset: values = [] flags_reduced = [] # Flags left after removing mutually excl. flags for flag in flags: - if flag not in flag_dict: + if flag not in self.flag_dict: raise ValueError( f"Did not find flag value meaning [{flag}] in known flag meanings:" - f" [{flag_dict.keys()!r}]" + f" [{self.flag_dict.keys()!r}]" ) - mask, value = flag_dict[flag] + mask, value = self.flag_dict[flag] if mask is None: out[flag] = self._obj == value else: From f5b259a2d6b448182a62804a00ee1e36768f6e3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Ha=C3=ABck?= <52320542+Descanonge@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:53:35 +0200 Subject: [PATCH 2/5] Add return type --- cf_xarray/accessor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 0ea17643..125f5c84 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1167,7 +1167,9 @@ def flag_dict(self) -> Mapping[Hashable, FlagParam]: ) return dict(zip(flag_meanings, flag_params)) - def _assert_valid_other_comparison(self, other: Hashable) -> Mapping[Hashable, FlagParam]: + def _assert_valid_other_comparison( + self, other: Hashable + ) -> Mapping[Hashable, FlagParam]: if other not in self.flag_dict: raise ValueError( f"Did not find flag value meaning [{other}] in known flag meanings: [{self.flag_dict.keys()!r}]" From 0fd6cd953e1247a507885556a654f472fb839017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Ha=C3=ABck?= <52320542+Descanonge@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:54:52 +0200 Subject: [PATCH 3/5] Cache flags dataset --- cf_xarray/accessor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 125f5c84..9acabe2f 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -2832,6 +2832,11 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): @xr.register_dataarray_accessor("cf") class CFDataArrayAccessor(CFAccessor): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._flags: Dataset | None = None + @property def formula_terms(self) -> dict[str, str]: # numpydoc ignore=SS06 """ @@ -2977,6 +2982,8 @@ def flags(self) -> Dataset: """ Dataset containing boolean masks of available flags. """ + if self._flags is not None: + return self._flags return self._extract_flags() def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset: From 7e584e7e129c7520f7dc91365a80fbc313772c68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Ha=C3=ABck?= <52320542+Descanonge@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:56:35 +0200 Subject: [PATCH 4/5] Cleaner implementation of flag extraction --- cf_xarray/accessor.py | 68 ++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 9acabe2f..1518d689 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -2993,45 +2993,47 @@ def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset: Parameters ---------- flags: Sequence[str] - Flags to extract. If empty (string or list), return all flags in - `flag_meanings`. + Flags to extract. If None, return all flags in `flag_meanings`. """ + flag_dict = self.flag_dict if flags is None: - flags = tuple(self.flag_dict.keys()) + flags = list(self.flag_dict.keys()) + else: + for flag in flags: + if flag not in self.flag_dict: + raise ValueError( + f"Did not find flag meaning [{flag}] in known flag meanings:" + f" [{self.flag_dict.keys()!r}]" + ) + flag_dict = {f: flag_dict[f] for f in flags} + + # Check if we are in simplified cases + all_mutually_exclusive = any(f.flag_mask is None for f in flag_dict.values()) + all_indep = any(f.flag_value is None for f in flag_dict.values()) out = {} # Output arrays - masks = [] # Bitmasks and values for asked flags - values = [] - flags_reduced = [] # Flags left after removing mutually excl. flags - for flag in flags: - if flag not in self.flag_dict: - raise ValueError( - f"Did not find flag value meaning [{flag}] in known flag meanings:" - f" [{self.flag_dict.keys()!r}]" - ) - mask, value = self.flag_dict[flag] - if mask is None: - out[flag] = self._obj == value + if all_mutually_exclusive: + for flag, params in flag_dict.items(): + out[flag] = self._obj == params.flag_value + return Dataset(out) + + # We cast both masks and flag variable as integers to make the + # bitwise comparison. + # TODO We could probably restrict the integer size + bit_mask = DataArray( + [f.flag_mask for f in flag_dict.values()], dims=["_mask"] + ).astype("i") + x = self._obj.astype("i") + + bit_comp = x & bit_mask + + for i, (flag, params) in enumerate(flag_dict.items()): + bit = bit_comp.isel(_mask=i) + if all_indep: + out[flag] = bit.astype(bool) else: - masks.append(mask) - values.append(value) - flags_reduced.append(flag) - - if len(masks) > 0: # If independant masks are left - # We cast both masks and flag variable as integers to make the - # bitwise comparison. We could probably restrict the integer size - # but it's difficult to make it safely for mixed type flags. - bit_mask = DataArray(masks, dims=["_mask"]).astype("i") - x = self._obj.astype("i") - bit_comp = x & bit_mask - - for i, (flag, value) in enumerate(zip(flags_reduced, values)): - bit = bit_comp.isel(_mask=i) - if value is not None: - out[flag] = bit == value - else: - out[flag] = bit.astype(bool) + out[flag] = bit == params.flag_value return Dataset(out) From 38cd65de258e73cbba94ae984882c3de155b3e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Ha=C3=ABck?= <52320542+Descanonge@users.noreply.github.com> Date: Tue, 25 Jun 2024 18:27:27 +0200 Subject: [PATCH 5/5] Fix formatting --- cf_xarray/formatting.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cf_xarray/formatting.py b/cf_xarray/formatting.py index 4fe58ccb..4c44ce3f 100644 --- a/cf_xarray/formatting.py +++ b/cf_xarray/formatting.py @@ -208,10 +208,8 @@ def find_set_bits(mask, value, repeated_masks, bit_length): def _format_flags(accessor, rich): - from .accessor import create_flag_dict - try: - flag_dict = create_flag_dict(accessor._obj) + flag_dict = accessor.flag_dict except ValueError: return _print_rows( "Flag Meanings", ["Invalid Mapping. Check attributes."], rich