diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index a915335b..1518d689 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1114,41 +1114,6 @@ def __getattr__(self, attr): ) -def create_flag_dict(da) -> Mapping[Hashable, FlagParam]: - """ - Return possible flag meanings and associated bitmask/values. - - The mapping values are a tuple containing a bitmask and a value. Either - can be None. - If only a bitmask: Independent flags. - If only a value: Mutually exclusive flags. - If both: Mix of independent and mutually exclusive flags. - """ - if not da.cf.is_flag_variable: - raise ValueError( - "Comparisons are only supported for DataArrays that represent " - "CF flag variables. .attrs must contain 'flag_meanings' and " - "'flag_values' or 'flag_masks'." - ) - - flag_meanings = da.attrs["flag_meanings"].split(" ") - n_flag = len(flag_meanings) - - flag_values = da.attrs.get("flag_values", [None] * n_flag) - flag_masks = da.attrs.get("flag_masks", [None] * n_flag) - - if not (n_flag == len(flag_values) == len(flag_masks)): - raise ValueError( - "Not as many flag meanings as values or masks. " - "Please check the flag_meanings, flag_values, flag_masks attributes " - ) - - flag_params = tuple( - FlagParam(mask, value) for mask, value in zip(flag_masks, flag_values) - ) - return dict(zip(flag_meanings, flag_params)) - - class CFAccessor: """ Common Dataset and DataArray accessor functionality. @@ -1157,23 +1122,64 @@ class CFAccessor: def __init__(self, obj): self._obj = obj self._all_cell_measures = None + self._flag_dict: Mapping[Hashable, FlagParam] | None = None def __setstate__(self, d): self.__dict__ = d - def _assert_valid_other_comparison(self, other): - # TODO cache this property - flag_dict = create_flag_dict(self._obj) - if other not in flag_dict: + @property + def flag_dict(self) -> Mapping[Hashable, FlagParam]: + """ + Return possible flag meanings and associated bitmask/values. + + The mapping values are a tuple containing a bitmask and a value. Either + can be None. + If only a bitmask: Independent flags. + If only a value: Mutually exclusive flags. + If both: Mix of independent and mutually exclusive flags. + """ + if self._flag_dict is not None: + return self._flag_dict + + da = self._obj + + if not da.cf.is_flag_variable: raise ValueError( - f"Did not find flag value meaning [{other}] in known flag meanings: [{flag_dict.keys()!r}]" + "Comparisons are only supported for DataArrays that represent " + "CF flag variables. .attrs must contain 'flag_meanings' and " + "'flag_values' or 'flag_masks'." ) - if flag_dict[other].flag_mask is not None: + + flag_meanings = da.attrs["flag_meanings"].split(" ") + n_flag = len(flag_meanings) + + flag_values = da.attrs.get("flag_values", [None] * n_flag) + flag_masks = da.attrs.get("flag_masks", [None] * n_flag) + + if not (n_flag == len(flag_values) == len(flag_masks)): + raise ValueError( + "Not as many flag meanings as values or masks. " + "Please check the flag_meanings, flag_values, flag_masks attributes " + ) + + flag_params = tuple( + FlagParam(mask, value) for mask, value in zip(flag_masks, flag_values) + ) + return dict(zip(flag_meanings, flag_params)) + + def _assert_valid_other_comparison( + self, other: Hashable + ) -> Mapping[Hashable, FlagParam]: + if other not in self.flag_dict: + raise ValueError( + f"Did not find flag value meaning [{other}] in known flag meanings: [{self.flag_dict.keys()!r}]" + ) + if self.flag_dict[other].flag_mask is not None: raise NotImplementedError( "Only equals and not-equals comparisons with flag masks are supported." " Please open an issue." ) - return flag_dict + return self.flag_dict def __eq__(self, other) -> DataArray: # type: ignore[override] """ @@ -1320,15 +1326,13 @@ def isin(self, test_elements) -> DataArray: raise ValueError( ".cf.isin is only supported on DataArrays that contain CF flag attributes." ) - # TODO cache this property - flag_dict = create_flag_dict(self._obj) mapped_test_elements = [] for elem in test_elements: - if elem not in flag_dict: + if elem not in self.flag_dict: raise ValueError( - f"Did not find flag value meaning [{elem}] in known flag meanings: [{flag_dict.keys()!r}]" + f"Did not find flag value meaning [{elem}] in known flag meanings: [{self.flag_dict.keys()!r}]" ) - mapped_test_elements.append(flag_dict[elem].flag_value) + mapped_test_elements.append(self.flag_dict[elem].flag_value) return self._obj.isin(mapped_test_elements) def _drop_missing_variables(self, variables: list[Hashable]) -> list[Hashable]: @@ -2828,6 +2832,11 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None): @xr.register_dataarray_accessor("cf") class CFDataArrayAccessor(CFAccessor): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._flags: Dataset | None = None + @property def formula_terms(self) -> dict[str, str]: # numpydoc ignore=SS06 """ @@ -2973,6 +2982,8 @@ def flags(self) -> Dataset: """ Dataset containing boolean masks of available flags. """ + if self._flags is not None: + return self._flags return self._extract_flags() def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset: @@ -2982,48 +2993,47 @@ def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset: Parameters ---------- flags: Sequence[str] - Flags to extract. If empty (string or list), return all flags in - `flag_meanings`. + Flags to extract. If None, return all flags in `flag_meanings`. """ - # TODO cache this property - flag_dict = create_flag_dict(self._obj) - + flag_dict = self.flag_dict if flags is None: - flags = tuple(flag_dict.keys()) + flags = list(self.flag_dict.keys()) + else: + for flag in flags: + if flag not in self.flag_dict: + raise ValueError( + f"Did not find flag meaning [{flag}] in known flag meanings:" + f" [{self.flag_dict.keys()!r}]" + ) + flag_dict = {f: flag_dict[f] for f in flags} + + # Check if we are in simplified cases + all_mutually_exclusive = any(f.flag_mask is None for f in flag_dict.values()) + all_indep = any(f.flag_value is None for f in flag_dict.values()) out = {} # Output arrays - masks = [] # Bitmasks and values for asked flags - values = [] - flags_reduced = [] # Flags left after removing mutually excl. flags - for flag in flags: - if flag not in flag_dict: - raise ValueError( - f"Did not find flag value meaning [{flag}] in known flag meanings:" - f" [{flag_dict.keys()!r}]" - ) - mask, value = flag_dict[flag] - if mask is None: - out[flag] = self._obj == value + if all_mutually_exclusive: + for flag, params in flag_dict.items(): + out[flag] = self._obj == params.flag_value + return Dataset(out) + + # We cast both masks and flag variable as integers to make the + # bitwise comparison. + # TODO We could probably restrict the integer size + bit_mask = DataArray( + [f.flag_mask for f in flag_dict.values()], dims=["_mask"] + ).astype("i") + x = self._obj.astype("i") + + bit_comp = x & bit_mask + + for i, (flag, params) in enumerate(flag_dict.items()): + bit = bit_comp.isel(_mask=i) + if all_indep: + out[flag] = bit.astype(bool) else: - masks.append(mask) - values.append(value) - flags_reduced.append(flag) - - if len(masks) > 0: # If independant masks are left - # We cast both masks and flag variable as integers to make the - # bitwise comparison. We could probably restrict the integer size - # but it's difficult to make it safely for mixed type flags. - bit_mask = DataArray(masks, dims=["_mask"]).astype("i") - x = self._obj.astype("i") - bit_comp = x & bit_mask - - for i, (flag, value) in enumerate(zip(flags_reduced, values)): - bit = bit_comp.isel(_mask=i) - if value is not None: - out[flag] = bit == value - else: - out[flag] = bit.astype(bool) + out[flag] = bit == params.flag_value return Dataset(out) diff --git a/cf_xarray/formatting.py b/cf_xarray/formatting.py index 4fe58ccb..4c44ce3f 100644 --- a/cf_xarray/formatting.py +++ b/cf_xarray/formatting.py @@ -208,10 +208,8 @@ def find_set_bits(mask, value, repeated_masks, bit_length): def _format_flags(accessor, rich): - from .accessor import create_flag_dict - try: - flag_dict = create_flag_dict(accessor._obj) + flag_dict = accessor.flag_dict except ValueError: return _print_rows( "Flag Meanings", ["Invalid Mapping. Check attributes."], rich